Compare commits

...

24 Commits
v2.5.2 ... oops

Author SHA1 Message Date
Yuhong Sun
8a38fdf8a5 ok 2025-11-29 22:46:22 -08:00
Vega
9155d4aa21 [FIX] Fix citation document mismatch and standardize citation format (#6484) 2025-11-29 22:43:03 -08:00
Yuhong Sun
b20591611a Single Commit Rebased 2025-11-29 21:22:58 -08:00
SubashMohan
83e756bf05 fix(Projects): file ordering in project panel (#6334) 2025-11-29 19:40:01 +00:00
Jamison Lahman
19b485cffd chore(deps): upgrade supervisor 4.2.5->4.3.0 (#6466) 2025-11-26 18:38:13 -05:00
Jamison Lahman
f5a99053ac chore(deps): upgrade dropbox 11.36.2->12.0.2 (#6467) 2025-11-26 18:10:13 -05:00
Chris Weaver
91f0377dd5 chore: enable code interpreter tests (#6404) 2025-11-26 14:55:07 -08:00
Jamison Lahman
25522dfbb8 chore(gha): setup-python accepts requirements to install (#6463) 2025-11-26 17:27:30 -05:00
Jamison Lahman
b0e124ec89 chore(deps): upgrade pytest-asyncio 0.22.0->1.3.0 (#6461) 2025-11-26 16:39:52 -05:00
Raunak Bhagat
b699a65384 refactor: Edit Modal.Header to be more concise and adherent to mocks (#6452) 2025-11-26 13:17:51 -08:00
Jamison Lahman
cc82d6e506 chore(deps): remove non-dev packages (#6462) 2025-11-26 16:17:01 -05:00
Jamison Lahman
8a6db7474d chore(gha): assert GHA jobs have timeouts (#6455) 2025-11-26 18:14:23 +00:00
Jamison Lahman
fd9aea212b chore(dev): run mypy and uv-sync on pre-commit (#6454) 2025-11-26 17:24:28 +00:00
acaprau
4aed383e49 chore(logs): When final doc for context pruning gets pruned, that prob doesn't need to be an error (#6451)
Co-authored-by: Andrei <andrei@Andreis-MacBook-Pro.local>
2025-11-25 22:41:47 -08:00
Justin Tahara
d0ce313b1a fix(google): Fix embedding scopes (#6450) 2025-11-25 22:10:42 -06:00
Jamison Lahman
4d32c9f5e0 chore(python): use uv to manage and compile requirements (#6291) 2025-11-26 03:01:52 +00:00
Justin Tahara
158fe31b71 fix(azure): Normalizing Azure Target URIs (#6443) 2025-11-26 00:19:22 +00:00
Raunak Bhagat
97cddc1dd4 fix: Line item cleanup (#6444)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-11-25 16:11:16 -08:00
Chris Weaver
c520a4ec17 fix: spinner during CSV load (#6441)
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
2025-11-25 22:01:55 +00:00
Raunak Bhagat
9c1f8cc98c refactor: Line item cleanup (#6434) 2025-11-25 13:51:17 -08:00
Justin Tahara
58ba8cc68a chore(langfuse): Remove Env Var (#6440) 2025-11-25 15:32:15 -06:00
Evan Lohn
a307b0d366 fix: use raw mcp url (#6432) 2025-11-25 21:10:03 +00:00
Wenxi
e34f58e994 refactor(tests): use PATManager for tests that use PATs (#6438) 2025-11-25 15:39:49 -05:00
Justin Tahara
7f6dd2dc93 feat(api): Add Users to Group Endpoint (#6427) 2025-11-25 20:12:20 +00:00
362 changed files with 38656 additions and 27744 deletions

View File

@@ -1,42 +0,0 @@
name: "Prepare Build (OpenAPI generation)"
description: "Sets up Python with uv, installs deps, generates OpenAPI schema and Python client, uploads artifact"
inputs:
docker-username:
required: true
docker-password:
required: true
runs:
using: "composite"
steps:
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
- name: Generate OpenAPI schema
shell: bash
working-directory: backend
env:
PYTHONPATH: "."
run: |
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ inputs['docker-username'] }}
password: ${{ inputs['docker-password'] }}
- name: Generate OpenAPI Python client
shell: bash
run: |
docker run --rm \
-v "${{ github.workspace }}/backend/generated:/local" \
openapitools/openapi-generator-cli generate \
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"

View File

@@ -1,5 +1,9 @@
name: "Setup Python and Install Dependencies"
description: "Sets up Python with uv and installs deps"
inputs:
requirements:
description: "Newline-separated list of requirement files to install (relative to repo root)"
required: true
runs:
using: "composite"
steps:
@@ -9,11 +13,26 @@ runs:
# with:
# enable-cache: true
- name: Compute requirements hash
id: req-hash
shell: bash
env:
REQUIREMENTS: ${{ inputs.requirements }}
run: |
# Hash the contents of the specified requirement files
hash=""
while IFS= read -r req; do
if [ -n "$req" ] && [ -f "$req" ]; then
hash="$hash$(sha256sum "$req")"
fi
done <<< "$REQUIREMENTS"
echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
- name: Cache uv cache directory
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/uv
key: ${{ runner.os }}-uv-${{ hashFiles('backend/requirements/*.txt', 'backend/pyproject.toml') }}
key: ${{ runner.os }}-uv-${{ steps.req-hash.outputs.hash }}
restore-keys: |
${{ runner.os }}-uv-
@@ -38,8 +57,16 @@ runs:
- name: Install Python dependencies with uv
shell: bash
env:
REQUIREMENTS: ${{ inputs.requirements }}
run: |
uv pip install \
-r backend/requirements/default.txt \
-r backend/requirements/dev.txt \
-r backend/requirements/model_server.txt
# Build the uv pip install command with each requirement file as array elements
cmd=("uv" "pip" "install")
while IFS= read -r req; do
# Skip empty lines
if [ -n "$req" ]; then
cmd+=("-r" "$req")
fi
done <<< "$REQUIREMENTS"
echo "Running: ${cmd[*]}"
"${cmd[@]}"

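For reference, a minimal Python sketch of the cache-key idea in the "Compute requirements hash" step above: hash each requirements file, then hash the concatenation. The exact bytes differ from the shell pipeline, and the file paths below are only examples.

import hashlib
from pathlib import Path

# Sketch: mirrors the idea of the hash step above, not its exact byte-for-byte output.
def requirements_cache_key(req_files: list[str]) -> str:
    combined = ""
    for req in req_files:
        path = Path(req)
        if path.is_file():
            digest = hashlib.sha256(path.read_bytes()).hexdigest()
            combined += f"{digest}  {req}"  # roughly the shape of `sha256sum <file>` output
    return hashlib.sha256(combined.encode()).hexdigest()

if __name__ == "__main__":
    # Example inputs matching the workflows below; adjust as needed.
    print(requirements_cache_key(["backend/requirements/default.txt", "backend/requirements/dev.txt"]))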
View File

@@ -16,6 +16,7 @@ permissions:
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout code

View File

@@ -18,6 +18,7 @@ jobs:
determine-builds:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
outputs:
build-web: ${{ steps.check.outputs.build-web }}
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
@@ -90,6 +91,7 @@ jobs:
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -147,6 +149,7 @@ jobs:
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -206,6 +209,7 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
@@ -253,6 +257,7 @@ jobs:
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -318,6 +323,7 @@ jobs:
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -385,6 +391,7 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
@@ -429,6 +436,7 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -485,6 +493,7 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -543,6 +552,7 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -591,6 +601,7 @@ jobs:
- run-id=${{ github.run_id }}-model-server-amd64
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -654,6 +665,7 @@ jobs:
- run-id=${{ github.run_id }}-model-server-arm64
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -718,6 +730,7 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
@@ -767,6 +780,7 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
@@ -806,6 +820,7 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
@@ -845,6 +860,7 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -891,6 +907,7 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
@@ -937,6 +954,7 @@ jobs:
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6

View File

@@ -18,6 +18,7 @@ jobs:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

View File

@@ -18,6 +18,7 @@ jobs:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

View File

@@ -12,6 +12,7 @@ jobs:
permissions:
contents: write
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6

View File

@@ -11,6 +11,7 @@ permissions:
jobs:
stale:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
with:

View File

@@ -20,6 +20,7 @@ jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
@@ -89,6 +90,7 @@ jobs:
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx

View File

@@ -12,12 +12,14 @@ permissions:
contents: read
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# AWS credentials for S3-specific test
S3_AWS_ACCESS_KEY_ID_FOR_TEST: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY_FOR_TEST: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
S3_AWS_ACCESS_KEY_ID: "minioadmin"
S3_AWS_SECRET_ACCESS_KEY: "minioadmin"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
@@ -33,12 +35,13 @@ env:
# Code Interpreter
# TODO: debug why this is failing and enable
# CODE_INTERPRETER_BASE_URL: http://localhost:8000
CODE_INTERPRETER_BASE_URL: http://localhost:8000
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
@@ -62,6 +65,7 @@ jobs:
- runner=2cpu-linux-arm64
- ${{ format('run-id={0}-external-dependency-unit-tests-job-{1}', github.run_id, strategy['job-index']) }}
- extras=s3-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
@@ -81,6 +85,11 @@ jobs:
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/ee.txt
- name: Setup Playwright
uses: ./.github/actions/setup-playwright
@@ -94,6 +103,12 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
CODE_INTERPRETER_BETA_ENABLED=true
EOF
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
@@ -125,3 +140,30 @@ jobs:
-xv \
--ff \
backend/tests/external_dependency_unit/${TEST_DIR}
- name: Collect Docker logs on failure
if: failure()
run: |
mkdir -p docker-logs
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
# Collect logs from each container
for container in $containers; do
container_name=$(docker inspect --format='{{.Name}}' $container | sed 's/^\///')
echo "Collecting logs from $container_name..."
docker logs $container > ../../docker-logs/${container_name}.log 2>&1
done
cd ../..
echo "Docker logs collected in docker-logs directory"
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/
retention-days: 7
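The MinIO values added above (S3_ENDPOINT_URL plus the minioadmin credentials) point tests at a local S3-compatible endpoint. A minimal boto3 sketch of connecting to it; the bucket and key names are made up for illustration.

import os
import boto3

# Sketch: talk to the local MinIO endpoint configured in the workflow env above.
s3 = boto3.client(
    "s3",
    endpoint_url=os.environ.get("S3_ENDPOINT_URL", "http://localhost:9004"),
    aws_access_key_id=os.environ.get("S3_AWS_ACCESS_KEY_ID", "minioadmin"),
    aws_secret_access_key=os.environ.get("S3_AWS_SECRET_ACCESS_KEY", "minioadmin"),
)

# Hypothetical bucket/key, not taken from the repo.
s3.create_bucket(Bucket="example-test-bucket")
s3.put_object(Bucket="example-test-bucket", Key="hello.txt", Body=b"hello")
print(s3.get_object(Bucket="example-test-bucket", Key="hello.txt")["Body"].read())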

View File

@@ -16,6 +16,7 @@ jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
timeout-minutes: 45
# fetch-depth 0 is required for helm/chart-testing-action
steps:

View File

@@ -35,6 +35,7 @@ jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
@@ -66,6 +67,7 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -102,6 +104,7 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -136,6 +139,7 @@ jobs:
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -181,6 +185,7 @@ jobs:
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
@@ -362,6 +367,7 @@ jobs:
build-integration-image,
]
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -491,6 +497,7 @@ jobs:
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:

View File

@@ -12,6 +12,7 @@ jobs:
jest-tests:
name: Jest Tests
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6

View File

@@ -16,6 +16,7 @@ permissions:
jobs:
validate_pr_title:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR title for Conventional Commits
env:

View File

@@ -13,6 +13,7 @@ permissions:
jobs:
linear-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for Linear link or override
env:

View File

@@ -32,6 +32,7 @@ jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
@@ -62,6 +63,7 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -96,6 +98,7 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -129,6 +132,7 @@ jobs:
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
@@ -174,6 +178,7 @@ jobs:
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
@@ -349,6 +354,7 @@ jobs:
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:

View File

@@ -47,6 +47,7 @@ env:
jobs:
build-web-image:
runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -82,6 +83,7 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -118,6 +120,7 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -160,6 +163,7 @@ jobs:
- "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}"
- "extras=ecr-cache"
- volume=50gb
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
@@ -378,6 +382,7 @@ jobs:
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests]
if: ${{ always() }}
steps:

View File

@@ -14,11 +14,30 @@ permissions:
contents: read
jobs:
validate-requirements:
runs-on: ubuntu-slim
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup uv
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
# TODO: Enable caching once there is a uv.lock file checked in.
# with:
# enable-cache: true
- name: Validate requirements lock files
run: ./backend/scripts/compile_requirements.py --check
mypy-check:
# See https://runs-on.com/runners/linux/
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -27,6 +46,23 @@ jobs:
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Generate OpenAPI schema
shell: bash
working-directory: backend
env:
PYTHONPATH: "."
run: |
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -36,11 +72,18 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Prepare build
uses: ./.github/actions/prepare-build
with:
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-password: ${{ secrets.DOCKER_TOKEN }}
- name: Generate OpenAPI Python client
shell: bash
run: |
docker run --rm \
-v "${{ github.workspace }}/backend/generated:/local" \
openapitools/openapi-generator-cli generate \
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}

View File

@@ -126,6 +126,7 @@ jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-connectors-check", "extras=s3-cache"]
timeout-minutes: 45
env:
PYTHONPATH: ./backend
@@ -140,6 +141,10 @@ jobs:
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Setup Playwright
uses: ./.github/actions/setup-playwright

View File

@@ -32,6 +32,7 @@ jobs:
model-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
timeout-minutes: 45
env:
PYTHONPATH: ./backend

View File

@@ -17,6 +17,7 @@ jobs:
backend-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-backend-check"]
timeout-minutes: 45
env:
@@ -36,6 +37,12 @@ jobs:
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -14,6 +14,7 @@ jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-quality-checks"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
@@ -26,5 +27,13 @@ jobs:
- name: Setup Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
env:
# uv-run is the id of the mypy hook; mypy is already covered by the Python Checks workflow, which caches dependencies better.

SKIP: uv-run
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- name: Check Actions
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
with:
check_permissions: false
check_versions: false

View File

@@ -9,6 +9,7 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: read
steps:

View File

@@ -11,6 +11,7 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
timeout-minutes: 45
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes

View File

@@ -12,6 +12,7 @@ jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
timeout-minutes: 45
permissions:
security-events: write # needed for SARIF uploads
steps:

.gitignore
View File

@@ -49,5 +49,7 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl
node_modules
# MCP configs
.playwright-mcp

View File

@@ -1,4 +1,20 @@
default_install_hook_types:
- pre-commit
- post-checkout
- post-merge
- post-rewrite
repos:
- repo: https://github.com/astral-sh/uv-pre-commit
# This revision is from https://github.com/astral-sh/uv-pre-commit/pull/53
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-sync
- id: uv-run
name: mypy
args: ["mypy"]
pass_filenames: true
files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
@@ -71,31 +87,3 @@ repos:
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/(?!\.venv/).*\.py$
# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in its own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type
# stubs. This does not seem worth it at the moment, so for now we will stick to
# having mypy run via Github Actions / manually by contributors
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.1.1
# hooks:
# - id: mypy
# exclude: ^tests/
# # below are needed for type stubs since pre-commit runs in its own
# # isolated environment. Unfortunately, this needs to be kept in sync
# # with requirements/dev.txt + requirements/default.txt
# additional_dependencies: [
# alembic==1.10.4,
# types-beautifulsoup4==4.12.0.3,
# types-html5lib==1.1.11.13,
# types-oauthlib==3.2.0.9,
# types-psycopg2==2.9.21.10,
# types-python-dateutil==2.8.19.13,
# types-regex==2023.3.23.1,
# types-requests==2.28.11.17,
# types-retry==0.9.9.3,
# types-urllib3==1.26.25.11
# ]
# # TODO: add back once errors are addressed
# # args: [--strict]

View File

@@ -0,0 +1,104 @@
"""add_open_url_tool
Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None
OPEN_URL_TOOL = {
"name": "OpenURLTool",
"display_name": "Open URL",
"description": (
"The Open URL Action allows the agent to fetch and read contents of web pages."
),
"in_code_tool_id": "OpenURLTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
tool_id = existing[0]
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
OPEN_URL_TOOL,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
OPEN_URL_TOOL,
)
# Get the newly inserted tool's id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
tool_id = result[0] # type: ignore
# Associate the tool with all existing personas
# Get all persona IDs
persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()
for (persona_id,) in persona_ids:
# Check if association already exists
exists = conn.execute(
sa.text(
"""
SELECT 1 FROM persona__tool
WHERE persona_id = :persona_id AND tool_id = :tool_id
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
).fetchone()
if not exists:
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (:persona_id, :tool_id)
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
)
def downgrade() -> None:
# We don't remove the tool on downgrade since it's fine to have it around.
# If we upgrade again, it will be a no-op.
pass
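A quick sanity check that could be run after applying this migration, as a sketch only; the database URL is a placeholder, and the query simply confirms the tool row exists and is attached to every persona.

import os
from sqlalchemy import create_engine, text

# Sketch: verify the OpenURLTool row exists and is linked to every persona.
engine = create_engine(os.environ.get("DATABASE_URL", "postgresql://localhost/onyx"))  # assumed URL
with engine.connect() as conn:
    tool_id = conn.execute(
        text("SELECT id FROM tool WHERE in_code_tool_id = 'OpenURLTool'")
    ).scalar_one()
    missing = conn.execute(
        text(
            """
            SELECT p.id FROM persona p
            LEFT JOIN persona__tool pt ON pt.persona_id = p.id AND pt.tool_id = :tool_id
            WHERE pt.persona_id IS NULL
            """
        ),
        {"tool_id": tool_id},
    ).fetchall()
    print(f"tool id: {tool_id}, personas missing the tool: {len(missing)}")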

View File

@@ -0,0 +1,44 @@
"""add_created_at_in_project_userfile
Revision ID: 6436661d5b65
Revises: c7e9f4a3b2d1
Create Date: 2025-11-24 11:50:24.536052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6436661d5b65"
down_revision = "c7e9f4a3b2d1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add created_at column to project__user_file table
op.add_column(
"project__user_file",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
)
# Add composite index on (project_id, created_at DESC)
op.create_index(
"ix_project__user_file_project_id_created_at",
"project__user_file",
["project_id", sa.text("created_at DESC")],
)
def downgrade() -> None:
# Remove composite index on (project_id, created_at)
op.drop_index(
"ix_project__user_file_project_id_created_at", table_name="project__user_file"
)
# Remove created_at column from project__user_file table
op.drop_column("project__user_file", "created_at")
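The composite index above is shaped for newest-first listing of a project's files. A sketch of the kind of query it serves, written against the table directly; parameter values are placeholders.

from sqlalchemy import text

# Sketch: the newest-first listing that ix_project__user_file_project_id_created_at supports.
RECENT_PROJECT_FILES = text(
    """
    SELECT *
    FROM project__user_file
    WHERE project_id = :project_id
    ORDER BY created_at DESC
    LIMIT :limit
    """
)

# Usage with an existing SQLAlchemy connection `conn` (placeholder values):
# rows = conn.execute(RECENT_PROJECT_FILES, {"project_id": 1, "limit": 50}).fetchall()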

View File

@@ -0,0 +1,572 @@
"""New Chat History
Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop research agent tables (if they exist)
op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
# Drop agent sub query and sub question tables (if they exist)
op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")
# Update ChatMessage table
# Rename parent_message to parent_message_id and make it a foreign key (if not already done)
conn = op.get_bind()
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'parent_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message", "parent_message", new_column_name="parent_message_id"
)
op.create_foreign_key(
"fk_chat_message_parent_message_id",
"chat_message",
"chat_message",
["parent_message_id"],
["id"],
)
# Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message",
"latest_child_message",
new_column_name="latest_child_message_id",
)
op.create_foreign_key(
"fk_chat_message_latest_child_message_id",
"chat_message",
"chat_message",
["latest_child_message_id"],
["id"],
)
# Add reasoning_tokens column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Drop columns no longer needed (if they exist)
for col in [
"rephrased_query",
"alternate_assistant_id",
"overridden_model",
"is_agentic",
"refined_answer_improvement",
"research_type",
"research_plan",
"research_answer_purpose",
]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = '{col}'
"""
)
)
if result.fetchone():
op.drop_column("chat_message", col)
# Update ToolCall table
# Add chat_session_id column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
)
op.create_foreign_key(
"fk_tool_call_chat_session_id",
"tool_call",
"chat_session",
["chat_session_id"],
["id"],
)
# Rename message_id to parent_chat_message_id and make nullable (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'message_id'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call",
"message_id",
new_column_name="parent_chat_message_id",
nullable=True,
)
# Add parent_tool_call_id (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_tool_call_parent_tool_call_id",
"tool_call",
"tool_call",
["parent_tool_call_id"],
["id"],
)
op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
# Add turn_number, tool_id (if not exists)
for col_name in ["turn_number", "tool_id"]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = '{col_name}'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
)
# Add tool_call_id as String (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
)
# Add reasoning_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Rename tool_arguments to tool_call_arguments (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call", "tool_arguments", new_column_name="tool_call_arguments"
)
# Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name, data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_result'
"""
)
)
tool_result_row = result.fetchone()
if tool_result_row:
op.alter_column(
"tool_call", "tool_result", new_column_name="tool_call_response"
)
# Change type from JSONB to Text
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
else:
# Check if tool_call_response already exists and is JSONB, then convert to Text
result = conn.execute(
sa.text(
"""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
"""
)
)
tool_call_response_row = result.fetchone()
if tool_call_response_row and tool_call_response_row[0] == "jsonb":
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
# Add tool_call_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(
"tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
),
)
# Add generated_images column for image generation tool replay (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'generated_images'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
)
# Drop tool_name column (if exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_name'
"""
)
)
if result.fetchone():
op.drop_column("tool_call", "tool_name")
# Create tool_call__search_doc association table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT table_name FROM information_schema.tables
WHERE table_name = 'tool_call__search_doc'
"""
)
)
if not result.fetchone():
op.create_table(
"tool_call__search_doc",
sa.Column("tool_call_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
)
# Add replace_base_system_prompt to persona table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
"""
)
)
if not result.fetchone():
op.add_column(
"persona",
sa.Column(
"replace_base_system_prompt",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
# Reverse persona changes
op.drop_column("persona", "replace_base_system_prompt")
# Drop tool_call__search_doc association table
op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")
# Reverse ToolCall changes
op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
op.drop_column("tool_call", "tool_id")
op.drop_column("tool_call", "tool_call_tokens")
op.drop_column("tool_call", "generated_images")
# Change tool_call_response back to JSONB before renaming
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE JSONB
USING tool_call_response::jsonb
"""
)
)
op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
op.alter_column(
"tool_call", "tool_call_arguments", new_column_name="tool_arguments"
)
op.drop_column("tool_call", "reasoning_tokens")
op.drop_column("tool_call", "tool_call_id")
op.drop_column("tool_call", "turn_number")
op.drop_constraint(
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
)
op.drop_column("tool_call", "parent_tool_call_id")
op.alter_column(
"tool_call",
"parent_chat_message_id",
new_column_name="message_id",
nullable=False,
)
op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "chat_session_id")
op.add_column(
"chat_message",
sa.Column(
"research_answer_purpose",
sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
nullable=True,
),
)
op.add_column(
"chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
)
op.add_column(
"chat_message",
sa.Column(
"research_type",
sa.Enum("SIMPLE", "DEEP", name="researchtype"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)
op.drop_column("chat_message", "reasoning_tokens")
op.drop_constraint(
"fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message",
"latest_child_message_id",
new_column_name="latest_child_message",
)
op.drop_constraint(
"fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message", "parent_message_id", new_column_name="parent_message"
)
# Recreate agent sub question and sub query tables
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_question", sa.Text(), nullable=False),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_answer", sa.Text(), nullable=False),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("parent_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_query", sa.Text(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query__search_doc",
sa.Column("sub_query_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
)
# Recreate research agent tables
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
sa.Column("queries", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
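The migration above repeats the same information_schema existence check for each column; a small helper of the kind that pattern could be factored into, shown as a sketch only.

import sqlalchemy as sa
from sqlalchemy.engine import Connection

# Sketch: the column-existence check pattern repeated throughout the migration above.
def column_exists(conn: Connection, table: str, column: str) -> bool:
    row = conn.execute(
        sa.text(
            """
            SELECT 1 FROM information_schema.columns
            WHERE table_name = :table AND column_name = :column
            """
        ),
        {"table": table, "column": column},
    ).fetchone()
    return row is not None

# e.g. guarding an idempotent add_column:
# if not column_exists(conn, "chat_message", "reasoning_tokens"):
#     op.add_column("chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True))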

View File

@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
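Net effect of the four hunks above: message attribution now goes only through the session's persona. A sketch of the shared filter shape; the import paths for ChatMessage, ChatSession and MessageType are assumptions based on neighboring files.

from datetime import datetime

from sqlalchemy import func, select

from onyx.configs.constants import MessageType  # import path assumed
from onyx.db.models import ChatMessage, ChatSession  # import path assumed

# Sketch: the simplified filter shared by the four analytics queries after this change.
def assistant_message_count_stmt(persona_id: int, start: datetime, end: datetime):
    return (
        select(func.count(ChatMessage.id))
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .where(
            ChatSession.persona_id == persona_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
        )
    )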

View File

@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
if chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(

View File

@@ -581,6 +581,48 @@ def update_user_curator_relationship(
db_session.commit()
def add_users_to_user_group(
db_session: Session,
user: User | None,
user_group_id: int,
user_ids: list[UUID],
) -> UserGroup:
db_user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
missing_users = [
user_id for user_id in user_ids if fetch_user_by_id(db_session, user_id) is None
]
if missing_users:
raise ValueError(
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
)
_check_user_group_is_modifiable(db_user_group)
current_user_ids = [user.id for user in db_user_group.users]
current_user_ids_set = set(current_user_ids)
new_user_ids = [
user_id for user_id in user_ids if user_id not in current_user_ids_set
]
if not new_user_ids:
return db_user_group
user_group_update = UserGroupUpdate(
user_ids=current_user_ids + new_user_ids,
cc_pair_ids=[cc_pair.id for cc_pair in db_user_group.cc_pairs],
)
return update_user_group(
db_session=db_session,
user=user,
user_group_id=user_group_id,
user_group_update=user_group_update,
)
def update_user_group(
db_session: Session,
user: User | None,
@@ -603,6 +645,17 @@ def update_user_group(
added_user_ids = list(updated_user_ids - current_user_ids)
removed_user_ids = list(current_user_ids - updated_user_ids)
if added_user_ids:
missing_users = [
user_id
for user_id in added_user_ids
if fetch_user_by_id(db_session, user_id) is None
]
if missing_users:
raise ValueError(
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
)
# LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES
# ACCESS TO DIFFERENT PERMISSIONS
# if (removed_user_ids or added_user_ids) and (

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
parent_message, _ = create_chat_history_chain(
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:

View File

@@ -8,10 +8,29 @@ from pydantic import model_validator
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
class BasicCreateChatMessageRequest(ChunkContext):
@@ -71,17 +90,17 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
class AgentSubQuestion(BaseModel):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
class AgentAnswer(BaseModel):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
class AgentSubQuery(BaseModel):
sub_query: str
query_id: int
@@ -127,12 +146,3 @@ class AgentSubQuery(SubQuestionIdentifier):
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class StandardAnswerRequest(BaseModel):
message: str
slack_bot_categories: list[str]
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)

View File

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
last_message, messages = create_chat_chain(
messages = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None

View File

@@ -4,12 +4,14 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
@@ -79,6 +81,26 @@ def patch_user_group(
raise HTTPException(status_code=404, detail=str(e))
@router.post("/admin/user-group/{user_group_id}/add-users")
def add_users(
user_group_id: int,
add_users_request: AddUsersToUserGroupRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
add_users_to_user_group(
db_session=db_session,
user=user,
user_group_id=user_group_id,
user_ids=add_users_request.user_ids,
)
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/admin/user-group/{user_group_id}/set-curator")
def set_user_curator(
user_group_id: int,

View File

@@ -87,6 +87,10 @@ class UserGroupUpdate(BaseModel):
cc_pair_ids: list[int]
class AddUsersToUserGroupRequest(BaseModel):
user_ids: list[UUID]
class SetCuratorRequest(BaseModel):
user_id: UUID
is_curator: bool
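Putting the new endpoint and request model together, a sketch of how a client might call it; the base URL, /api prefix, auth cookie and IDs are all assumptions.

import requests

# Sketch: calling POST /admin/user-group/{user_group_id}/add-users with an AddUsersToUserGroupRequest body.
BASE_URL = "http://localhost:8080/api"  # assumed host and path prefix
USER_GROUP_ID = 7  # hypothetical group id

resp = requests.post(
    f"{BASE_URL}/admin/user-group/{USER_GROUP_ID}/add-users",
    json={"user_ids": ["7f9d1f2e-0000-0000-0000-000000000001"]},  # UUIDs of existing users
    headers={"Cookie": "fastapiusersauth=<session-cookie>"},  # auth mechanism assumed
)
resp.raise_for_status()
print(resp.json())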

View File

@@ -1,365 +1,309 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
# import json
# from collections.abc import Callable
# from collections.abc import Iterator
# from collections.abc import Sequence
# from dataclasses import dataclass
# from typing import Any
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
# from onyx.agents.agent_framework.models import RunItemStreamEvent
# from onyx.agents.agent_framework.models import StreamEvent
# from onyx.agents.agent_framework.models import ToolCallStreamItem
# from onyx.llm.interfaces import LanguageModelInput
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.llm.message_types import ChatCompletionMessage
# from onyx.llm.message_types import ToolCall
# from onyx.llm.model_response import ModelResponseStream
# from onyx.tools.tool import Tool
# from onyx.tracing.framework.create import agent_span
# from onyx.tracing.framework.create import generation_span
@dataclass
class QueryResult:
stream: Iterator[StreamEvent]
new_messages_stateful: list[ChatCompletionMessage]
# @dataclass
# class QueryResult:
# stream: Iterator[StreamEvent]
# new_messages_stateful: list[ChatCompletionMessage]
def _serialize_tool_output(output: Any) -> str:
if isinstance(output, str):
return output
try:
return json.dumps(output)
except TypeError:
return str(output)
# def _serialize_tool_output(output: Any) -> str:
# if isinstance(output, str):
# return output
# try:
# return json.dumps(output)
# except TypeError:
# return str(output)
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
# def _parse_tool_calls_from_message_content(
# content: str,
# ) -> list[dict[str, Any]]:
# """Parse JSON content that represents tool call instructions."""
# try:
# parsed_content = json.loads(content)
# except json.JSONDecodeError:
# return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
# if isinstance(parsed_content, dict):
# candidates = [parsed_content]
# elif isinstance(parsed_content, list):
# candidates = [item for item in parsed_content if isinstance(item, dict)]
# else:
# return []
tool_calls: list[dict[str, Any]] = []
# tool_calls: list[dict[str, Any]] = []
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
# for candidate in candidates:
# name = candidate.get("name")
# arguments = candidate.get("arguments")
if not isinstance(name, str) or arguments is None:
continue
# if not isinstance(name, str) or arguments is None:
# continue
if not isinstance(arguments, dict):
continue
# if not isinstance(arguments, dict):
# continue
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
# call_id = candidate.get("id")
# arguments_str = json.dumps(arguments)
# tool_calls.append(
# {
# "id": call_id,
# "name": name,
# "arguments": arguments_str,
# }
# )
return tool_calls
# return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
# def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# content_parts: list[str],
# structured_response_format: dict | None,
# next_synthetic_tool_call_id: Callable[[], str],
# ) -> None:
# """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
# if tool_calls_in_progress or not content_parts or structured_response_format:
# return
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
# tool_calls_from_content = _parse_tool_calls_from_message_content(
# "".join(content_parts)
# )
if not tool_calls_from_content:
return
# if not tool_calls_from_content:
# return
content_parts.clear()
# content_parts.clear()
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
# for index, tool_call_data in enumerate(tool_calls_from_content):
# call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
# tool_calls_in_progress[index] = {
# "id": call_id,
# "name": tool_call_data["name"],
# "arguments": tool_call_data["arguments"],
# }
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
) -> None:
index = tool_call_delta.index
# def _update_tool_call_with_delta(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# tool_call_delta: Any,
# ) -> None:
# index = tool_call_delta.index
if index not in tool_calls_in_progress:
tool_calls_in_progress[index] = {
"id": None,
"name": None,
"arguments": "",
}
# if index not in tool_calls_in_progress:
# tool_calls_in_progress[index] = {
# "id": None,
# "name": None,
# "arguments": "",
# }
if tool_call_delta.id:
tool_calls_in_progress[index]["id"] = tool_call_delta.id
# if tool_call_delta.id:
# tool_calls_in_progress[index]["id"] = tool_call_delta.id
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
# if tool_call_delta.function:
# if tool_call_delta.function.name:
# tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
if tool_call_delta.function.arguments:
tool_calls_in_progress[index][
"arguments"
] += tool_call_delta.function.arguments
# if tool_call_delta.function.arguments:
# tool_calls_in_progress[index][
# "arguments"
# ] += tool_call_delta.function.arguments
def query(
llm_with_default_settings: LLM,
messages: LanguageModelInput,
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
# def query(
# llm_with_default_settings: LLM,
# messages: LanguageModelInput,
# tools: Sequence[Tool],
# context: Any,
# tool_choice: ToolChoiceOptions | None = None,
# structured_response_format: dict | None = None,
# ) -> QueryResult:
# tool_definitions = [tool.tool_definition() for tool in tools]
# tools_by_name = {tool.name: tool for tool in tools}
new_messages_stateful: list[ChatCompletionMessage] = []
# new_messages_stateful: list[ChatCompletionMessage] = []
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
# current_span = agent_span(
# name="agent_framework_query",
# output_type="dict" if structured_response_format else "str",
# )
# current_span.start(mark_as_current=True)
# current_span.span_data.tools = [t.name for t in tools]
def stream_generator() -> Iterator[StreamEvent]:
message_started = False
reasoning_started = False
# def stream_generator() -> Iterator[StreamEvent]:
# message_started = False
# reasoning_started = False
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
# tool_calls_in_progress: dict[int, dict[str, Any]] = {}
content_parts: list[str] = []
# content_parts: list[str] = []
synthetic_tool_call_counter = 0
# synthetic_tool_call_counter = 0
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
# def _next_synthetic_tool_call_id() -> str:
# nonlocal synthetic_tool_call_counter
# call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
# synthetic_tool_call_counter += 1
# return call_id
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
# with generation_span( # type: ignore[misc]
# model=llm_with_default_settings.config.model_name,
# model_config={
# "base_url": str(llm_with_default_settings.config.api_base or ""),
# "model_impl": "litellm",
# },
# ) as span_generation:
# # Only set input if messages is a sequence (not a string)
# # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
# if isinstance(messages, Sequence) and not isinstance(messages, str):
# # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
# span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
# for chunk in llm_with_default_settings.stream(
# prompt=messages,
# tools=tool_definitions,
# tool_choice=tool_choice,
# structured_response_format=structured_response_format,
# ):
# assert isinstance(chunk, ModelResponseStream)
# usage = getattr(chunk, "usage", None)
# if usage:
# span_generation.span_data.usage = {
# "input_tokens": usage.prompt_tokens,
# "output_tokens": usage.completion_tokens,
# "cache_read_input_tokens": usage.cache_read_input_tokens,
# "cache_creation_input_tokens": usage.cache_creation_input_tokens,
# }
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
# delta = chunk.choice.delta
# finish_reason = chunk.choice.finish_reason
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
# if delta.reasoning_content:
# if not reasoning_started:
# yield RunItemStreamEvent(type="reasoning_start")
# reasoning_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
# if delta.content:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# content_parts.append(delta.content)
# if not message_started:
# yield RunItemStreamEvent(type="message_start")
# message_started = True
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# if delta.tool_calls:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
# for tool_call_delta in delta.tool_calls:
# _update_tool_call_with_delta(
# tool_calls_in_progress, tool_call_delta
# )
yield chunk
# yield chunk
if not finish_reason:
continue
# if not finish_reason:
# continue
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
)
# if tool_choice != "none":
# _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress,
# content_parts,
# structured_response_format,
# _next_synthetic_tool_call_id,
# )
if content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
span_generation.span_data.output = new_messages_stateful
# if content_parts:
# new_messages_stateful.append(
# {
# "role": "assistant",
# "content": "".join(content_parts),
# }
# )
# span_generation.span_data.output = new_messages_stateful
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# # Execute tool calls outside of the stream loop and generation_span
# if tool_calls_in_progress:
# sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
# # Build tool calls for the message and execute tools
# assistant_tool_calls: list[ToolCall] = []
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
# for _, tool_call_data in sorted_tool_calls:
# call_id = tool_call_data["id"]
# name = tool_call_data["name"]
# arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
# if call_id is None or name is None:
# continue
assistant_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
# assistant_tool_calls.append(
# {
# "id": call_id,
# "type": "function",
# "function": {
# "name": name,
# "arguments": arguments_str,
# },
# }
# )
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
# yield RunItemStreamEvent(
# type="tool_call",
# details=ToolCallStreamItem(
# call_id=call_id,
# name=name,
# arguments=arguments_str,
# ),
# )
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
# if name in tools_by_name:
# tools_by_name[name]
# json.loads(arguments_str)
run_context = RunContextWrapper(context=context)
# run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
# TODO broken for now, no need for a run_v2
# output = tool.run_v2(run_context, **arguments)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),
new_messages_stateful=new_messages_stateful,
)
# yield RunItemStreamEvent(
# type="tool_call_output",
# details=ToolCallOutputStreamItem(
# call_id=call_id,
# output=output,
# ),
# )
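The loop being commented out above accumulates streamed tool-call fragments by stream index before executing tools. Below is a minimal, self-contained sketch of that accumulation pattern; FunctionDelta and ToolCallDelta are illustrative stand-ins for the provider's streaming chunk types, not Onyx classes.

```python
# Minimal sketch of merging streamed tool-call deltas into complete calls.
# FunctionDelta / ToolCallDelta are stand-ins for the streaming chunk types.
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class FunctionDelta:
    name: str | None = None
    arguments: str | None = None


@dataclass
class ToolCallDelta:
    index: int
    id: str | None = None
    function: FunctionDelta | None = None


def merge_delta(buffer: dict[int, dict[str, Any]], delta: ToolCallDelta) -> None:
    """Fold one streamed fragment into the per-index buffer of in-progress calls."""
    call = buffer.setdefault(delta.index, {"id": None, "name": None, "arguments": ""})
    if delta.id:
        call["id"] = delta.id
    if delta.function:
        if delta.function.name:
            call["name"] = delta.function.name
        if delta.function.arguments:
            call["arguments"] += delta.function.arguments


# One tool call split across three chunks, as a streaming LLM would emit it.
buffer: dict[int, dict[str, Any]] = {}
for chunk in [
    ToolCallDelta(index=0, id="call_0", function=FunctionDelta(name="search")),
    ToolCallDelta(index=0, function=FunctionDelta(arguments='{"query": "user gro')),
    ToolCallDelta(index=0, function=FunctionDelta(arguments='ups"}')),
]:
    merge_delta(buffer, chunk)

print(buffer[0]["name"], json.loads(buffer[0]["arguments"]))
```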


@@ -1,21 +1,21 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from pydantic import BaseModel
# from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class CoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
# log_messages: Annotated[list[str], add] = []
# current_step_nr: int = 1
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class SubgraphCoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
# log_messages: Annotated[list[str], add] = []


@@ -1,62 +1,62 @@
from collections.abc import Hashable
from typing import cast
# from collections.abc import Hashable
# from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
# def parallel_object_source_research_edge(
# state: SearchSourcesObjectsUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
search_objects = state.analysis_objects
search_sources = state.analysis_sources
# search_objects = state.analysis_objects
# search_sources = state.analysis_sources
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
# object_source_combinations = [
# (object, source) for object in search_objects for source in search_sources
# ]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
# return [
# Send(
# "research_object_source",
# ObjectSourceInput(
# object_source_combination=object_source_combination,
# log_messages=[],
# ),
# )
# for object_source_combination in object_source_combinations
# ]
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
# def parallel_object_research_consolidation_edge(
# state: ObjectResearchInformationUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
# cast(GraphConfig, config["metadata"]["config"])
# object_research_information_results = state.object_research_information_results
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]
# return [
# Send(
# "consolidate_object_research",
# ObjectInformationInput(
# object_information=object_information,
# log_messages=[],
# ),
# )
# for object_information in object_research_information_results
# ]


@@ -1,103 +1,103 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
# search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
# research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
# structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
# consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
# consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
test_mode = False
# test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
### Add nodes ###
# ### Add nodes ###
graph.add_node(
"search_objects",
search_objects,
)
# graph.add_node(
# "search_objects",
# search_objects,
# )
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
# graph.add_node(
# "structure_research_by_source",
# structure_research_by_object,
# )
graph.add_node(
"research_object_source",
research_object_source,
)
# graph.add_node(
# "research_object_source",
# research_object_source,
# )
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
# graph.add_node(
# "consolidate_object_research",
# consolidate_object_research,
# )
graph.add_node(
"consolidate_research",
consolidate_research,
)
# graph.add_node(
# "consolidate_research",
# consolidate_research,
# )
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="search_objects")
# graph.add_edge(start_key=START, end_key="search_objects")
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
# graph.add_conditional_edges(
# source="search_objects",
# path=parallel_object_source_research_edge,
# path_map=["research_object_source"],
# )
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
# graph.add_edge(
# start_key="research_object_source",
# end_key="structure_research_by_source",
# )
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
# graph.add_conditional_edges(
# source="structure_research_by_source",
# path=parallel_object_research_consolidation_edge,
# path_map=["consolidate_object_research"],
# )
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
# graph.add_edge(
# start_key="consolidate_object_research",
# end_key="consolidate_research",
# )
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
# graph.add_edge(
# start_key="consolidate_research",
# end_key=END,
# )
return graph
# return graph
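The edges and graph builder above follow LangGraph's Send-based fan-out: a conditional edge returns one Send per work item, the target node runs once per Send, and an Annotated reducer field merges the parallel results back into the shared state. A minimal, self-contained sketch of that pattern follows; the State, plan, and work names are illustrative and not Onyx symbols.

```python
# Minimal Send-based fan-out sketch; node and state names are illustrative only.
from operator import add
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from pydantic import BaseModel


class State(BaseModel):
    items: list[str] = []
    results: Annotated[list[str], add] = []  # reducer merges parallel branch outputs


def plan(state: State) -> dict:
    return {"items": ["a", "b", "c"]}


def fan_out(state: State) -> list[Send]:
    # One Send per item; LangGraph invokes "work" once for each, in parallel.
    return [Send("work", State(items=[item])) for item in state.items]


def work(state: State) -> dict:
    return {"results": [f"processed {state.items[0]}"]}


graph = StateGraph(State)
graph.add_node("plan", plan)
graph.add_node("work", work)
graph.add_edge(START, "plan")
graph.add_conditional_edges(source="plan", path=fan_out, path_map=["work"])
graph.add_edge("work", END)

print(graph.compile().invoke({"items": []}))
```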


@@ -1,146 +1,146 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def search_objects(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
search_tool = graph_config.tooling.search_tool
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = graph_config.inputs.prompt_builder.raw_user_query
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
# agent_1_instructions = extract_section(
# instructions, "Agent Step 1:", "Agent Step 2:"
# )
# if agent_1_instructions is None:
# raise ValueError("Agent 1 instructions not found")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
# agent_1_task = extract_section(
# agent_1_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_1_task is None:
# raise ValueError("Agent 1 task not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
# agent_1_independent_sources_str = extract_section(
# agent_1_instructions, "Independent Research Sources:", "Output Objective:"
# )
# if agent_1_independent_sources_str is None:
# raise ValueError("Agent 1 Independent Research Sources not found")
document_sources = strings_to_document_sources(
[
x.strip().lower()
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
)
# document_sources = strings_to_document_sources(
# [
# x.strip().lower()
# for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
# ]
# )
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
# agent_1_output_objective = extract_section(
# agent_1_instructions, "Output Objective:"
# )
# if agent_1_output_objective is None:
# raise ValueError("Agent 1 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# except Exception as e:
# raise ValueError(
# f"Agent 1 instructions not found or not formatted correctly: {e}"
# )
# Extract objects
# # Extract objects
if agent_1_base_data is None:
# Retrieve chunks for objects
# if agent_1_base_data is None:
# # Retrieve chunks for objects
retrieved_docs = research(question, search_tool)[:10]
# retrieved_docs = research(question, search_tool)[:10]
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
# dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# document_text=document_texts,
# objects_of_interest=agent_1_output_objective,
# )
# else:
# dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# base_data=agent_1_base_data,
# objects_of_interest=agent_1_output_objective,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_extraction_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# cleaned_response = cleaned_response.split("OBJECTS:")[1]
# object_list = [x.strip() for x in cleaned_response.split(";")]
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
# except Exception as e:
# raise ValueError(f"Error in search_objects: {e}")
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)
# return SearchSourcesObjectsUpdate(
# analysis_objects=object_list,
# analysis_sources=document_sources,
# log_messages=["Agent 1 Task done"],
# )


@@ -1,180 +1,180 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
# def research_object_source(
# state: ObjectSourceInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
object, document_source = state.object_source_combination
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
# object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
# agent_2_instructions = extract_section(
# instructions, "Agent Step 2:", "Agent Step 3:"
# )
# if agent_2_instructions is None:
# raise ValueError("Agent 2 instructions not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
# agent_2_task = extract_section(
# agent_2_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_2_task is None:
# raise ValueError("Agent 2 task not found")
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
# agent_2_time_cutoff = extract_section(
# agent_2_instructions, "Time Cutoff:", "Research Topics:"
# )
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
# agent_2_research_topics = extract_section(
# agent_2_instructions, "Research Topics:", "Output Objective"
# )
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
# agent_2_output_objective = extract_section(
# agent_2_instructions, "Output Objective:"
# )
# if agent_2_output_objective is None:
# raise ValueError("Agent 2 output objective not found")
except Exception:
raise ValueError(
"Agent 1 instructions not found or not formatted correctly: {e}"
)
# except Exception:
# raise ValueError(
# "Agent 1 instructions not found or not formatted correctly: {e}"
# )
# Populate prompt
# # Populate prompt
# Retrieve chunks for objects
# # Retrieve chunks for objects
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
# if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
# if agent_2_time_cutoff.strip().endswith("d"):
# try:
# days = int(agent_2_time_cutoff.strip()[:-1])
# agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
# days=days
# )
# except ValueError:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# agent_2_source_start_time = None
document_sources = [document_source] if document_source else None
# document_sources = [document_source] if document_source else None
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
# if len(question.strip()) > 0:
# research_area = f"{question} for {object}"
# elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
# research_area = f"{agent_2_research_topics} for {object}"
# else:
# research_area = object
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# retrieved_docs = research(
# question=research_area,
# search_tool=search_tool,
# document_sources=document_sources,
# time_cutoff=agent_2_source_start_time,
# )
# Generate document text
# # Generate document text
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
# Built prompt
# # Built prompt
today = datetime.now().strftime("%A, %Y-%m-%d")
# today = datetime.now().strftime("%A, %Y-%m-%d")
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# dc_object_source_research_prompt = (
# DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
# today=today,
# question=question,
# task=agent_2_task,
# document_text=document_texts,
# format=agent_2_output_objective,
# )
# .replace("---object---", object)
# .replace("---source---", document_source.value)
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
# object_research_results = {
# "object": object,
# "source": document_source.value,
# "research_result": cleaned_response,
# }
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
# except Exception as e:
# raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)
# return ObjectSourceResearchUpdate(
# object_source_research_results=[object_research_results],
# log_messages=["Agent Step 2 done for one object"],
# )


@@ -1,48 +1,48 @@
from collections import defaultdict
from typing import Dict
from typing import List
# from collections import defaultdict
# from typing import Dict
# from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def structure_research_by_object(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
# """
# LangGraph node to start the agentic search process.
# """
object_source_research_results = state.object_source_research_results
# object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
# object_research_information_results: List[Dict[str, str]] = []
# object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
# for object_source_research in object_source_research_results:
# object = object_source_research["object"]
# source = object_source_research["source"]
# research_result = object_source_research["research_result"]
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
# object_research_information_results_list[object].append(
# f"Source: {source}\n{research_result}"
# )
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
# for object, information in object_research_information_results_list.items():
# object_research_information_results.append(
# {"object": object, "information": "\n".join(information)}
# )
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
# logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)
# return ObjectResearchInformationUpdate(
# object_research_information_results=object_research_information_results,
# log_messages=["A3 - Object Research Information structured"],
# )


@@ -1,103 +1,103 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
# def consolidate_object_research(
# state: ObjectInformationInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.persona.system_prompt or ""
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
# agent_4_instructions = extract_section(
# instructions, "Agent Step 4:", "Agent Step 5:"
# )
# if agent_4_instructions is None:
# raise ValueError("Agent 4 instructions not found")
# agent_4_output_objective = extract_section(
# agent_4_instructions, "Output Objective:"
# )
# if agent_4_output_objective is None:
# raise ValueError("Agent 4 output objective not found")
object_information = state.object_information
# object_information = state.object_information
object = object_information["object"]
information = object_information["information"]
# object = object_information["object"]
# information = object_information["information"]
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
# question=question,
# object=object,
# information=information,
# format=agent_4_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_consolidation_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# consolidated_information = cleaned_response.split("INFORMATION:")[1]
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_object_research: {e}")
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
# object_research_results = {
# "object": object,
# "research_result": consolidated_information,
# }
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
# logger.debug(
# "DivCon Step A4 - Object Research Consolidation - completed for an object"
# )
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)
# return ObjectResearchUpdate(
# object_research_results=[object_research_results],
# log_messages=["Agent Source Consilidation done"],
# )


@@ -1,127 +1,127 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def consolidate_research(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.persona.system_prompt or ""
# # Populate prompt
# instructions = graph_config.inputs.persona.system_prompt or ""
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
# try:
# agent_5_instructions = extract_section(
# instructions, "Agent Step 5:", "Agent End"
# )
# if agent_5_instructions is None:
# raise ValueError("Agent 5 instructions not found")
# agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_5_task = extract_section(
# agent_5_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_5_task is None:
# raise ValueError("Agent 5 task not found")
# agent_5_output_objective = extract_section(
# agent_5_instructions, "Output Objective:"
# )
# if agent_5_output_objective is None:
# raise ValueError("Agent 5 output objective not found")
# except ValueError as e:
# raise ValueError(
# f"Instructions for Agent Step 5 were not properly formatted: {e}"
# )
research_result_list = []
# research_result_list = []
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
# if agent_5_task.strip() == "*concatenate*":
# object_research_results = state.object_research_results
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
# for object_research_result in object_research_results:
# object = object_research_result["object"]
# research_result = object_research_result["research_result"]
# research_result_list.append(f"Object: {object}\n\n{research_result}")
research_results = "\n\n".join(research_result_list)
# research_results = "\n\n".join(research_result_list)
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# else:
# raise NotImplementedError("Only '*concatenate*' is currently supported")
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# if agent_5_base_data is None:
# dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
# text=research_results,
# format=agent_5_output_objective,
# )
# else:
# dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
# base_data=agent_5_base_data,
# text=research_results,
# format=agent_5_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_formatting_prompt,
# reserved_str="",
# ),
# )
# ]
try:
_ = run_with_timeout(
60,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)
# try:
# _ = run_with_timeout(
# 60,
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=msg,
# event_name="initial_agent_answer",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=30,
# max_tokens=None,
# ),
# )
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
logger.debug("DivCon Step A5 - Final Generation - completed")
# logger.debug("DivCon Step A5 - Final Generation - completed")
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)
# return ResearchUpdate(
# research_results=research_results,
# log_messages=["Agent Source Consilidation done"],
# )
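As an aside, a minimal stand-alone sketch of the "*concatenate*" branch above; the per-object dicts mirror the {"object": ..., "research_result": ...} shape built earlier in this diff, and the sample values are made up for illustration.

object_research_results = [
    {"object": "ACME Corp", "research_result": "Findings about ACME."},
    {"object": "Globex", "research_result": "Findings about Globex."},
]

# the same concatenation the node performs before handing the text to the
# formatting prompt; only the data is illustrative
research_result_list = [
    f"Object: {entry['object']}\n\n{entry['research_result']}"
    for entry in object_research_results
]
research_results = "\n\n".join(research_result_list)
print(research_results)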


@@ -1,61 +1,50 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
# # new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
# retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
# break
# return retrieved_docs
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
# def extract_section(
# text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
# """Extract text between markers, returning None if markers not found"""
# parts = text.split(start_marker)
if len(parts) == 1:
return None
# if len(parts) == 1:
# return None
after_start = parts[1].strip()
# after_start = parts[1].strip()
if not end_marker:
return after_start
# if not end_marker:
# return after_start
extract = after_start.split(end_marker)[0]
# extract = after_start.split(end_marker)[0]
return extract.strip()
# return extract.strip()
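For reference, a self-contained sketch of the marker-splitting behaviour of extract_section above; the persona text and section markers below are invented for illustration.

def extract_section(
    text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
    # same splitting logic as the helper above, repeated so the snippet runs on its own
    parts = text.split(start_marker)
    if len(parts) == 1:
        return None
    after_start = parts[1].strip()
    if not end_marker:
        return after_start
    return after_start.split(end_marker)[0].strip()


instructions = (
    "Agent Step 5: Task: *concatenate* "
    "Independent Research Sources: none "
    "Output Objective: one bullet list per object Agent End"
)
agent_5 = extract_section(instructions, "Agent Step 5:", "Agent End")
assert agent_5 is not None
print(extract_section(agent_5, "Task:", "Independent Research Sources:"))  # *concatenate*
print(extract_section(agent_5, "Output Objective:"))  # one bullet list per object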


@@ -1,72 +1,72 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# ### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
# analysis_objects: list[str] = []
# analysis_sources: list[DocumentSource] = []
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceInput(LoggerUpdate):
# object_source_combination: tuple[str, DocumentSource]
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectSourceResearchUpdate(LoggerUpdate):
# object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
# class ObjectInformationInput(LoggerUpdate):
# object_information: Dict[str, str]
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchInformationUpdate(LoggerUpdate):
# object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
# object_research_results: Annotated[list[Dict[str, str]], add] = []
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
# class ResearchUpdate(LoggerUpdate):
# research_results: str | None = None
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# SearchSourcesObjectsUpdate,
# ObjectSourceResearchUpdate,
# ObjectResearchInformationUpdate,
# ObjectResearchUpdate,
# ResearchUpdate,
# ):
# pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
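A small stand-alone sketch (assuming pydantic is installed) of why these state fields are declared as Annotated[list[...], add]: LangGraph treats the second Annotated argument as a reducer, so list updates coming back from parallel branches are concatenated rather than overwritten. The branch payloads below are illustrative.

from operator import add
from typing import Annotated, Dict

from pydantic import BaseModel


class ObjectResearchUpdate(BaseModel):
    object_research_results: Annotated[list[Dict[str, str]], add] = []


branch_a = ObjectResearchUpdate(
    object_research_results=[{"object": "ACME Corp", "research_result": "..."}]
)
branch_b = ObjectResearchUpdate(
    object_research_results=[{"object": "Globex", "research_result": "..."}]
)

# what the reducer effectively does when both branches update the same key
merged = add(branch_a.object_research_results, branch_b.object_research_results)
print(merged)  # both entries are preserved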


@@ -1,36 +1,36 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class RefinementSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
# class RefinementSubQuestion(BaseModel):
# sub_question: str
# sub_question_id: str
# verified: bool
# answered: bool
# answer: str
class AgentTimings(BaseModel):
base_duration_s: float | None
refined_duration_s: float | None
full_duration_s: float | None
# class AgentTimings(BaseModel):
# base_duration_s: float | None
# refined_duration_s: float | None
# full_duration_s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration_s: float | None = None
# class AgentBaseMetrics(BaseModel):
# num_verified_documents_total: int | None
# num_verified_documents_core: int | None
# verified_avg_score_core: float | None
# num_verified_documents_base: int | float | None
# verified_avg_score_base: float | None = None
# base_doc_boost_factor: float | None = None
# support_boost_factor: float | None = None
# duration_s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
# refined_doc_boost_factor: float | None = None
# refined_question_boost_factor: float | None = None
# duration_s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass
# class AgentAdditionalMetrics(BaseModel):
# pass


@@ -1,61 +1,61 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
# from langgraph.graph import END
# from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
# if not state.tools_used:
# raise IndexError("state.tools_used cannot be empty")
# next_tool is either a generic tool name or a DRPath string
next_tool_name = state.tools_used[-1]
# # next_tool is either a generic tool name or a DRPath string
# next_tool_name = state.tools_used[-1]
available_tools = state.available_tools
if not available_tools:
raise ValueError("No tool is available. This should not happen.")
# available_tools = state.available_tools
# if not available_tools:
# raise ValueError("No tool is available. This should not happen.")
if next_tool_name in available_tools:
next_tool_path = available_tools[next_tool_name].path
elif next_tool_name == DRPath.END.value:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
elif next_tool_name == DRPath.CLOSER.value:
return DRPath.CLOSER
else:
return DRPath.ORCHESTRATOR
# if next_tool_name in available_tools:
# next_tool_path = available_tools[next_tool_name].path
# elif next_tool_name == DRPath.END.value:
# return END
# elif next_tool_name == DRPath.LOGGER.value:
# return DRPath.LOGGER
# elif next_tool_name == DRPath.CLOSER.value:
# return DRPath.CLOSER
# else:
# return DRPath.ORCHESTRATOR
# handle invalid paths
if next_tool_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# # handle invalid paths
# if next_tool_path == DRPath.CLARIFIER:
# raise ValueError("CLARIFIER is not a valid path during iteration")
# handle tool calls without a query
if (
next_tool_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.WEB_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
# # handle tool calls without a query
# if (
# next_tool_path
# in (
# DRPath.INTERNAL_SEARCH,
# DRPath.WEB_SEARCH,
# DRPath.KNOWLEDGE_GRAPH,
# DRPath.IMAGE_GENERATION,
# )
# and len(state.query_list) == 0
# ):
# return DRPath.CLOSER
return next_tool_path
# return next_tool_path
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# def completeness_router(state: MainState) -> DRPath | str:
# if not state.tools_used:
# raise IndexError("tools_used cannot be empty")
# go to closer if path is CLOSER or no queries
next_path = state.tools_used[-1]
# # go to closer if path is CLOSER or no queries
# next_path = state.tools_used[-1]
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return DRPath.LOGGER
# if next_path == DRPath.ORCHESTRATOR.value:
# return DRPath.ORCHESTRATOR
# return DRPath.LOGGER
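A simplified stand-alone sketch of the routing pattern in decision_router above: the last entry of tools_used is looked up among the available tools, with fallbacks for the special CLOSER and ORCHESTRATOR names. The real router maps names to OrchestratorTool.path; the dict below is a deliberately reduced stand-in.

from enum import Enum


class DRPath(str, Enum):
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    CLOSER = "Closer"


# reduced stand-in: the real code maps names to OrchestratorTool objects
available_tools = {"Internal Search": DRPath.INTERNAL_SEARCH}


def route(tools_used: list[str]) -> DRPath:
    next_tool_name = tools_used[-1]
    if next_tool_name in available_tools:
        return available_tools[next_tool_name]
    if next_tool_name == DRPath.CLOSER.value:
        return DRPath.CLOSER
    return DRPath.ORCHESTRATOR


print(route(["Orchestrator", "Internal Search"]))  # DRPath.INTERNAL_SEARCH
print(route(["Internal Search", "Closer"]))        # DRPath.CLOSER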


@@ -1,31 +1,27 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# from onyx.agents.agent_search.dr.enums import DRPath
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
# MAX_CHAT_HISTORY_MESSAGES = (
# 3 # note: actual count is x2 to account for user and assistant messages
# )
MAX_DR_PARALLEL_SEARCH = 4
# MAX_DR_PARALLEL_SEARCH = 4
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
# 0 # how many times the closer can send back to the orchestrator
# )
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.WEB_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
# DRPath.INTERNAL_SEARCH: 1.0,
# DRPath.KNOWLEDGE_GRAPH: 2.0,
# DRPath.WEB_SEARCH: 1.5,
# DRPath.IMAGE_GENERATION: 3.0,
# DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
# DRPath.CLOSER: 0.0,
# }
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 12.0,
ResearchType.FAST: 0.5,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0


@@ -1,112 +1,111 @@
from datetime import datetime
# from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
# def get_dr_prompt_orchestration_templates(
# purpose: DRPromptPurpose,
# use_agentic_search: bool,
# available_tools: dict[str, OrchestratorTool],
# entity_types_string: str | None = None,
# relationship_types_string: str | None = None,
# reasoning_result: str | None = None,
# tool_calls_string: str | None = None,
# ) -> PromptTemplate:
# available_tools = available_tools or {}
# tool_names = list(available_tools.keys())
# tool_description_str = "\n\n".join(
# f"- {tool_name}: {tool.description}"
# for tool_name, tool in available_tools.items()
# )
# tool_cost_str = "\n".join(
# f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
# )
tool_differentiations: list[str] = [
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
for tool_1 in available_tools
for tool_2 in available_tools
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
]
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool delineation pairs for custom tools as well
# tool_differentiations: list[str] = [
# TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
# for tool_1 in available_tools
# for tool_2 in available_tools
# if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
# ]
# tool_differentiation_hint_string = (
# "\n".join(tool_differentiations) or "(No differentiating hints available)"
# )
# # TODO: add tool delineation pairs for custom tools as well
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
# tool_question_hint_string = (
# "\n".join(
# "- " + TOOL_QUESTION_HINTS[tool]
# for tool in available_tools
# if tool in TOOL_QUESTION_HINTS
# )
# or "(No examples available)"
# )
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
entity_types_string or relationship_types_string
):
# if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
# entity_types_string or relationship_types_string
# ):
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string or "",
possible_relationships=relationship_types_string or "",
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
# kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
# possible_entities=entity_types_string or "",
# possible_relationships=relationship_types_string or "",
# )
# else:
# kg_types_descriptions = "(The Knowledge Graph is not used.)"
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# if purpose == DRPromptPurpose.PLAN:
# if not use_agentic_search:
# raise ValueError("plan generation is only supported for agentic search")
# base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"reasoning is not separately required for DEEP time budget"
)
# elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# else:
# raise ValueError(
# "reasoning is not separately required for agentic search"
# )
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
# base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# else:
# base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
# elif purpose == DRPromptPurpose.CLARIFICATION:
# if not use_agentic_search:
# raise ValueError("clarification is only supported for agentic search")
# base_template = GET_CLARIFICATION_PROMPT
else:
# for mypy, clearly a mypy bug
raise ValueError(f"Invalid purpose: {purpose}")
# else:
# # for mypy, clearly a mypy bug
# raise ValueError(f"Invalid purpose: {purpose}")
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)
# return base_template.partial_build(
# num_available_tools=str(len(tool_names)),
# available_tools=", ".join(tool_names),
# tool_choice_options=" or ".join(tool_names),
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# kg_types_descriptions=kg_types_descriptions,
# tool_descriptions=tool_description_str,
# tool_differentiation_hints=tool_differentiation_hint_string,
# tool_question_hints=tool_question_hint_string,
# average_tool_costs=tool_cost_str,
# reasoning_result=reasoning_result or "(No reasoning result provided.)",
# tool_calls_string=tool_calls_string or "(No tool calls provided.)",
# )
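For illustration, a stand-alone sketch of how the tool description and cost strings above are assembled; SimpleNamespace stands in for OrchestratorTool and the tool entries are made up.

from types import SimpleNamespace

available_tools = {
    "Internal Search": SimpleNamespace(
        description="Search connected document sources.", cost=1.0
    ),
    "Web Search": SimpleNamespace(description="Search the public web.", cost=1.5),
}

tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
    f"- {tool_name}: {tool.description}"
    for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
    f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)

print(", ".join(tool_names))
print(tool_description_str)
print(tool_cost_str)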


@@ -1,32 +1,22 @@
from enum import Enum
# from enum import Enum
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# class ResearchAnswerPurpose(str, Enum):
# """Research answer purpose options for agent search operations"""
# BASIC = "BASIC"
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
FAST = "FAST"
# ANSWER = "ANSWER"
# CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Internal Search"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
WEB_SEARCH = "Web Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"
# class DRPath(str, Enum):
# CLARIFIER = "Clarifier"
# ORCHESTRATOR = "Orchestrator"
# INTERNAL_SEARCH = "Internal Search"
# GENERIC_TOOL = "Generic Tool"
# KNOWLEDGE_GRAPH = "Knowledge Graph Search"
# WEB_SEARCH = "Web Search"
# IMAGE_GENERATION = "Image Generation"
# GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
# CLOSER = "Closer"
# LOGGER = "Logger"
# END = "End"


@@ -1,88 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
# dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
# dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
# dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
# dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
# dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
# dr_ws_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
# def dr_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the deep research agent.
# """
graph = StateGraph(state_schema=MainState, input=MainInput)
# graph = StateGraph(state_schema=MainState, input=MainInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node(DRPath.CLARIFIER, clarifier)
# graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
# graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
# basic_search_graph = dr_basic_search_graph_builder().compile()
# graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
# kg_search_graph = dr_kg_search_graph_builder().compile()
# graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
internet_search_graph = dr_ws_graph_builder().compile()
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
# internet_search_graph = dr_ws_graph_builder().compile()
# graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
# image_generation_graph = dr_image_generation_graph_builder().compile()
# graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
# custom_tool_graph = dr_custom_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
# generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)
# graph.add_node(DRPath.CLOSER, closer)
# graph.add_node(DRPath.LOGGER, logging)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
# graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
# graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
# graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
# graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
# graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
return graph
# return graph
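A minimal runnable sketch (assuming the langgraph package) of the same wiring pattern: nodes registered under path names, a router attached via add_conditional_edges, and sub-steps looping back into the orchestrator. The node names, state, and router here are simplified stand-ins, not the real onyx implementations.

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    tools_used: list[str]


def orchestrator(state: State) -> State:
    # stand-in decision: always hand off to the closer after one pass
    return {"tools_used": state["tools_used"] + ["closer"]}


def closer(state: State) -> State:
    return state


def decision_router(state: State) -> str:
    return "closer" if state["tools_used"][-1] == "closer" else "orchestrator"


graph = StateGraph(State)
graph.add_node("orchestrator", orchestrator)
graph.add_node("closer", closer)
graph.add_edge(START, "orchestrator")
graph.add_conditional_edges("orchestrator", decision_router)
graph.add_edge("closer", END)

app = graph.compile()
print(app.invoke({"tools_used": ["orchestrator"]}))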


@@ -1,131 +1,131 @@
from enum import Enum
# from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
# from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
# class OrchestratorStep(BaseModel):
# tool: str
# questions: list[str]
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
# class OrchestratorDecisonsNoPlan(BaseModel):
# reasoning: str
# next_step: OrchestratorStep
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
# class OrchestrationPlan(BaseModel):
# reasoning: str
# plan: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
# class ClarificationGenerationResponse(BaseModel):
# clarification_needed: bool
# clarification_question: str
class DecisionResponse(BaseModel):
reasoning: str
decision: str
# class DecisionResponse(BaseModel):
# reasoning: str
# decision: str
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
# class QueryEvaluationResponse(BaseModel):
# reasoning: str
# query_permitted: bool
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
# class OrchestrationClarificationInfo(BaseModel):
# clarification_question: str
# clarification_response: str | None = None
class WebSearchAnswer(BaseModel):
urls_to_open_indices: list[int]
# class WebSearchAnswer(BaseModel):
# urls_to_open_indices: list[int]
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
# class SearchAnswer(BaseModel):
# reasoning: str
# answer: str
# claims: list[str] | None = None
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# class TestInfoCompleteResponse(BaseModel):
# reasoning: str
# complete: bool
# gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
# # TODO: revisit with custom tools implementation in v2
# # each tool should be a class with the attributes below, plus the actual tool implementation
# # this will also allow custom tools to have their own cost
# class OrchestratorTool(BaseModel):
# tool_id: int
# name: str
# llm_path: str # the path for the LLM to refer by
# path: DRPath # the actual path in the graph
# description: str
# metadata: dict[str, str]
# cost: float
# tool_object: Tool | None = None # None for CLOSER
model_config = ConfigDict(arbitrary_types_allowed=True)
# class Config:
# arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
# class IterationInstructions(BaseModel):
# iteration_nr: int
# plan: str | None
# reasoning: str
# purpose: str
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible to tools
queries: list[str] | None = None
# class IterationAnswer(BaseModel):
# tool: str
# tool_id: int
# iteration_nr: int
# parallelization_nr: int
# question: str
# reasoning: str | None
# answer: str
# cited_documents: dict[int, InferenceSection]
# background_info: str | None = None
# claims: list[str] | None = None
# additional_data: dict[str, str] | None = None
# response_type: str | None = None
# data: dict | list | str | int | float | bool | None = None
# file_ids: list[str] | None = None
# # TODO: This is not ideal, but we can rework the schema
# # for deep research later
# is_web_fetch: bool = False
# # for image generation step-types
# generated_images: list[GeneratedImage] | None = None
# # for multi-query search tools (v2 web search and internal search)
# # TODO: Clean this up to be more flexible to tools
# queries: list[str] | None = None
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
# class AggregatedDRContext(BaseModel):
# context: str
# cited_documents: list[InferenceSection]
# is_internet_marker_dict: dict[str, bool]
# global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
# class DRPromptPurpose(str, Enum):
# PLAN = "PLAN"
# NEXT_STEP = "NEXT_STEP"
# NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
# NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
# CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str
# class BaseSearchProcessingResponse(BaseModel):
# specified_source_types: list[str]
# rewritten_query: str
# time_filter: str
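A small stand-alone sketch of why OrchestratorTool sets arbitrary_types_allowed: pydantic v2 refuses plain classes (such as the Tool stand-in below) as field types unless that flag is enabled. The field list is trimmed down for illustration.

from pydantic import BaseModel, ConfigDict


class Tool:  # stand-in for onyx.tools.tool.Tool, illustration only
    pass


class OrchestratorTool(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cost: float
    tool_object: Tool | None = None  # None for CLOSER


print(OrchestratorTool(name="Internal Search", cost=1.0, tool_object=Tool()))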

File diff suppressed because it is too large


@@ -1,423 +1,418 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.
# def insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.
Args:
message_id: The ID of the chat message
search_doc_ids: The IDs of the search documents
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_ids: The IDs of the search documents
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = extract_citation_numbers(final_answer)
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = extract_citation_numbers(final_answer)
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# def closer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
research_type = graph_config.behavior.research_type
# use_agentic_search = graph_config.behavior.use_agentic_search
assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_task_prompt = state.assistant_task_prompt
# assistant_system_prompt: str = state.assistant_system_prompt or ""
# assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
# uploaded_context = state.uploaded_test_context or ""
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
# clarification = state.clarification
# prompt_question = get_prompt_question(base_question, clarification)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
# chat_history_string = (
# get_chat_history_string(
# graph_config.inputs.prompt_builder.message_history,
# MAX_CHAT_HISTORY_MESSAGES,
# )
# or "(No chat history yet available)"
# )
aggregated_context_w_docs = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context_w_docs = aggregate_context(
# state.iteration_responses, include_documents=True
# )
aggregated_context_wo_docs = aggregate_context(
state.iteration_responses, include_documents=False
)
# aggregated_context_wo_docs = aggregate_context(
# state.iteration_responses, include_documents=False
# )
iteration_responses_w_docs_string = aggregated_context_w_docs.context
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
all_cited_documents = aggregated_context_w_docs.cited_documents
# iteration_responses_w_docs_string = aggregated_context_w_docs.context
# iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
# all_cited_documents = aggregated_context_w_docs.cited_documents
num_closer_suggestions = state.num_closer_suggestions
# num_closer_suggestions = state.num_closer_suggestions
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_wo_docs_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
# if (
# num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
# and use_agentic_search
# ):
# test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
# base_question=prompt_question,
# questions_answers_claims=iteration_responses_wo_docs_string,
# chat_history_string=chat_history_string,
# high_level_plan=(
# state.plan_of_record.plan
# if state.plan_of_record
# else "No plan available"
# ),
# )
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1000,
)
# test_info_complete_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# test_info_complete_prompt + (assistant_task_prompt or ""),
# ),
# schema=TestInfoCompleteResponse,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1000,
# )
if test_info_complete_json.complete:
pass
# if test_info_complete_json.complete:
# pass
else:
return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
# else:
# return OrchestrationUpdate(
# tools_used=[DRPath.ORCHESTRATOR.value],
# query_list=[],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# gaps=test_info_complete_json.gaps,
# num_closer_suggestions=num_closer_suggestions + 1,
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# all_cited_documents
# )
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
# write_custom_event(
# current_step_nr,
# MessageStart(
# content="",
# final_documents=retrieved_search_docs,
# ),
# writer,
# )
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
elif research_type == ResearchType.DEEP:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
else:
raise ValueError(f"Invalid research type: {research_type}")
# if not use_agentic_search:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# else:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
estimated_final_answer_prompt_tokens = check_number_of_tokens(
final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_w_docs_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
)
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
# final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_w_docs_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
# )
# for DR, rely only on sub-answers and claims to save tokens if context is too long
# TODO: consider compression step for Thoughtful mode if context is too long.
# Should generally not be the case though.
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
# # TODO: consider compression step for Thoughtful mode if context is too long.
# # Should generally not be the case though.
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
if (
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
and research_type == ResearchType.DEEP
):
iteration_responses_string = iteration_responses_wo_docs_string
else:
iteration_responses_string = iteration_responses_w_docs_string
# if (
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
# and use_agentic_search
# ):
# iteration_responses_string = iteration_responses_wo_docs_string
# else:
# iteration_responses_string = iteration_responses_w_docs_string
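# Worked example with illustrative numbers (not from the diff): if max_input_tokens were
# 128_000, the cutoff above would be 0.8 * 128_000 = 102_400 tokens; a DEEP-mode prompt
# estimated at ~110_000 tokens would then fall back to iteration_responses_wo_docs_string,
# while THOUGHTFUL/FAST runs or smaller prompts keep the full document-bearing context.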
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
# final_answer_prompt = final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
# if graph_config.inputs.project_instructions:
# assistant_system_prompt = (
# assistant_system_prompt
# + PROJECT_INSTRUCTIONS_SEPARATOR
# + (graph_config.inputs.project_instructions or "")
# )
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
# all_context_llmdocs = [
# llm_doc_from_inference_section(inference_section)
# for inference_section in all_cited_documents
# ]
try:
streamed_output, _, citation_infos = run_with_timeout(
int(3 * TF_DR_TIMEOUT_LONG),
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
# try:
# streamed_output, _, citation_infos = run_with_timeout(
# int(3 * TF_DR_TIMEOUT_LONG),
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# final_answer_prompt + (assistant_task_prompt or ""),
# ),
# event_name="basic_response",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
# answer_piece=StreamingType.MESSAGE_DELTA.value,
# ind=current_step_nr,
# context_docs=all_context_llmdocs,
# replace_citations=True,
# # max_tokens=None,
# ),
# )
final_answer = "".join(streamed_output)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# final_answer = "".join(streamed_output)
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, CitationStart(), writer)
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
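# Illustrative event order (not from the diff): for an answer that starts at step N, the
# events written above stream roughly as
#     step N:   MessageStart -> MessageDelta ... MessageDelta -> SectionEnd
#     step N+1: CitationStart -> CitationDelta -> SectionEnd
# with current_step_nr advancing after each SectionEnd so later sections land on N+2.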
# Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# )
# # Log the research agent steps
# # save_iteration(
# # state,
# # graph_config,
# # aggregated_context,
# # final_answer,
# # all_cited_documents,
# # is_internet_marker_dict,
# # )
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)
# return FinalUpdate(
# final_answer=final_answer,
# all_cited_documents=all_cited_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,248 +1,246 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def _extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
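# Illustrative sketch (not part of the diff): following the docstring's format, a call like
#     _extract_citation_numbers("See [[3]] and [[5, 7]] for details.")
# would return the unique numbers 3, 5 and 7 in unspecified order, since the result is
# de-duplicated through set() before being returned.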
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
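# Illustrative sketch (not part of the diff; links are hypothetical): with three docs whose
# links are "https://a", "https://b" and "https://c", a regex match on "[[1, 3]]" would be
# rewritten by the helper above to "[[1]](https://a)[[3]](https://c)", while an out-of-range
# citation such as "[[9]]" is kept as plain "[[9]]" with no link.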
def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a (message_id, search_doc_id) row into the chat_message__search_doc table for each given search doc.
# def _insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
Args:
message_id: The ID of the chat message
search_doc_ids: The IDs of the search documents to link
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_id: The ID of the search document
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
num_tokens: int,
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# num_tokens: int,
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# _insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
if search_docs:
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = _extract_citation_numbers(final_answer)
# if search_docs:
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
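# Illustrative sketch (not part of the diff): if the final answer cites [[2]] and [[5]] and at
# least five search_docs were inserted above, citation_dict ends up as
# {2: search_docs[1].id, 5: search_docs[4].id}, i.e. each 1-based citation number maps to the
# database id of the corresponding SearchDoc row.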
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=num_tokens,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# token_count=num_tokens,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def logging(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to log the research agent iterations and persist the final answer.
"""
# def logging(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to log the research agent iterations and persist the final answer.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context = aggregate_context(
# state.iteration_responses, include_documents=True
# )
all_cited_documents = aggregated_context.cited_documents
# all_cited_documents = aggregated_context.cited_documents
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
# is_internet_marker_dict = aggregated_context.is_internet_marker_dict
final_answer = state.final_answer or ""
llm_provider = graph_config.tooling.primary_llm.config.model_provider
llm_model_name = graph_config.tooling.primary_llm.config.model_name
# final_answer = state.final_answer or ""
# llm_provider = graph_config.tooling.primary_llm.config.model_provider
# llm_model_name = graph_config.tooling.primary_llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
# llm_tokenizer = get_tokenizer(
# model_name=llm_model_name,
# provider_type=llm_provider,
# )
# num_tokens = len(llm_tokenizer.encode(final_answer or ""))
write_custom_event(current_step_nr, OverallStop(), writer)
# write_custom_event(current_step_nr, OverallStop(), writer)
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
num_tokens,
)
# # Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# num_tokens,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="logger",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="logger",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,132 +1,131 @@
from collections.abc import Iterator
from typing import cast
# from collections.abc import Iterator
# from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
# PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []
# class BasicSearchProcessedStreamResults(BaseModel):
# ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
# full_answer: str | None = None
# cited_references: list[InferenceSection] = []
# retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
search_results: list[LlmDoc] | None = None,
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
# def process_llm_stream(
# messages: Iterator[BaseMessage],
# should_stream_answer: bool,
# writer: StreamWriter,
# ind: int,
# search_results: list[LlmDoc] | None = None,
# generate_final_answer: bool = False,
# chat_message_id: str | None = None,
# ) -> BasicSearchProcessedStreamResults:
# tool_call_chunk = AIMessageChunk(content="")
if search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=search_results,
doc_id_to_rank_map=map_document_id_order(search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
# if search_results:
# answer_handler: AnswerResponseHandler = CitationResponseHandler(
# context_docs=search_results,
# doc_id_to_rank_map=map_document_id_order(search_results),
# )
# else:
# answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# full_answer = ""
# start_final_answer_streaming_set = False
# # Accumulate citation infos if handler emits them
# collected_citation_infos: list[CitationInfo] = []
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
# # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# # the stream will contain AIMessageChunks with tool call information.
# for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
# answer_piece = message.content
# if not isinstance(answer_piece, str):
# # this is only used for logging, so fine to
# # just add the string representation
# answer_piece = str(answer_piece)
# full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message):
# if isinstance(message, AIMessageChunk) and (
# message.tool_call_chunks or message.tool_calls
# ):
# tool_call_chunk += message # type: ignore
# elif should_stream_answer:
# for response_part in answer_handler.handle_response_part(message):
# only stream out answer parts
if (
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
# # only stream out answer parts
# if (
# isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
# and generate_final_answer
# and response_part.answer_piece
# ):
# if chat_message_id is None:
# raise ValueError(
# "chat_message_id is required when generating final answer"
# )
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
# if not start_final_answer_streaming_set:
# # Convert LlmDocs to SavedSearchDocs
# saved_search_docs = saved_search_docs_from_llm_docs(
# search_results
# )
# write_custom_event(
# ind,
# MessageStart(content="", final_documents=saved_search_docs),
# writer,
# )
# start_final_answer_streaming_set = True
write_custom_event(
ind,
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
# write_custom_event(
# ind,
# MessageDelta(content=response_part.answer_piece),
# writer,
# )
# # collect citation info objects
# elif isinstance(response_part, CitationInfo):
# collected_citation_infos.append(response_part)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
# if generate_final_answer and start_final_answer_streaming_set:
# # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
# write_custom_event(
# ind,
# SectionEnd(),
# writer,
# )
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
# # Emit citations section if any were collected
# if collected_citation_infos:
# write_custom_event(ind, CitationStart(), writer)
# write_custom_event(
# ind, CitationDelta(citations=collected_citation_infos), writer
# )
# write_custom_event(ind, SectionEnd(), writer)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)
# logger.debug(f"Full answer: {full_answer}")
# return BasicSearchProcessedStreamResults(
# ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
# )
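# Illustrative sketch (not part of the diff; argument values are hypothetical): a caller would
# use the helper above roughly as
#     result = process_llm_stream(
#         messages=llm.stream(prompt), should_stream_answer=True, writer=writer, ind=step_nr,
#         search_results=docs, generate_final_answer=True, chat_message_id=str(msg_id),
#     )
# reading result.ai_message_chunk for accumulated tool calls and result.full_answer for the text.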

View File

@@ -1,82 +1,82 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
### States ###
# ### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
gaps: list[str] = (
[]
) # gaps that may be identified by the closer before being able to answer the question.
iteration_instructions: Annotated[list[IterationInstructions], add] = []
# class OrchestrationUpdate(LoggerUpdate):
# tools_used: Annotated[list[str], add] = []
# query_list: list[str] = []
# iteration_nr: int = 0
# current_step_nr: int = 1
# plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
# remaining_time_budget: float = 2.0 # set by default to about 2 searches
# num_closer_suggestions: int = 0 # how many times the closer was suggested
# gaps: list[str] = (
# []
# ) # gaps that may be identified by the closer before being able to answer the question.
# iteration_instructions: Annotated[list[IterationInstructions], add] = []
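# Illustrative sketch (not part of the diff; field values are hypothetical): a single
# orchestration step might emit
#     OrchestrationUpdate(
#         tools_used=["internal_search"],
#         query_list=["quarterly revenue 2024"],
#         iteration_nr=1,
#         remaining_time_budget=1.0,
#     )
# with the Annotated[..., add] fields (log_messages, tools_used, iteration_instructions)
# accumulating across LangGraph nodes via the add reducer.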
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
# class OrchestrationSetup(OrchestrationUpdate):
# original_question: str | None = None
# chat_history_string: str | None = None
# clarification: OrchestrationClarificationInfo | None = None
# available_tools: dict[str, OrchestratorTool] | None = None
# num_closer_suggestions: int = 0 # how many times the closer was suggested
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
uploaded_test_context: str | None = None
uploaded_image_context: list[dict[str, Any]] | None = None
# active_source_types: list[DocumentSource] | None = None
# active_source_types_descriptions: str | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
# uploaded_test_context: str | None = None
# uploaded_image_context: list[dict[str, Any]] | None = None
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
# class AnswerUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
# class FinalUpdate(LoggerUpdate):
# final_answer: str | None = None
# all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# OrchestrationSetup,
# AnswerUpdate,
# FinalUpdate,
# ):
# pass
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]
# ## Graph Output State
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# all_cited_documents: list[InferenceSection]

View File

@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,286 +1,261 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
# SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
current_step_nr = state.current_step_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# current_step_nr = state.current_step_nr
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
# use_agentic_search = graph_config.behavior.use_agentic_search
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
elif len(state.tools_used) == 0:
raise ValueError("tools_used is empty")
# elif len(state.tools_used) == 0:
# raise ValueError("tools_used is empty")
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# search_tool_info = state.available_tools[state.tools_used[-1]]
# search_tool = cast(SearchTool, search_tool_info.tool_object)
# graph_config.tooling.force_use_tool
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# # sanity check
# if search_tool != graph_config.tooling.search_tool:
# raise ValueError("search_tool does not match the configured search tool")
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
# # rewrite query and identify source types
# active_source_types_str = ", ".join(
# [source.value for source in state.active_source_types or []]
# )
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
)
# base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
# active_source_types_str=active_source_types_str,
# branch_query=branch_query,
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
# )
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=TF_DR_TIMEOUT_SHORT,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
# try:
# search_processing = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, base_search_processing_prompt
# ),
# schema=BaseSearchProcessingResponse,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# # max_tokens=100,
# )
# except Exception as e:
# logger.error(f"Could not process query: {e}")
# raise e
rewritten_query = search_processing.rewritten_query
# rewritten_query = search_processing.rewritten_query
# give back the query so we can render it in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[rewritten_query],
documents=[],
),
writer,
)
# # give back the query so we can render it in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[rewritten_query],
# documents=[],
# ),
# writer,
# )
implied_start_date = search_processing.time_filter
# implied_start_date = search_processing.time_filter
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# # Validate time_filter format if it exists
# implied_time_filter = None
# if implied_start_date:
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
# # Check if time_filter is in YYYY-MM-DD format
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
# if re.match(date_pattern, implied_start_date):
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
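# Illustrative sketch (not part of the diff): a time_filter of "2024-06-01" matches the
# YYYY-MM-DD pattern above and becomes datetime(2024, 6, 1), whereas free-form text such as
# "last quarter" fails the regex and leaves implied_time_filter as None.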
specified_source_types: list[DocumentSource] | None = (
strings_to_document_sources(search_processing.specified_source_types)
if search_processing.specified_source_types
else None
)
# specified_source_types: list[DocumentSource] | None = (
# strings_to_document_sources(search_processing.specified_source_types)
# if search_processing.specified_source_types
# else None
# )
if specified_source_types is not None and len(specified_source_types) == 0:
specified_source_types = None
# if specified_source_types is not None and len(specified_source_types) == 0:
# specified_source_types = None
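# Illustrative sketch (not part of the diff; values are hypothetical): for a branch query like
# "confluence pages about onboarding since June", the processing step above might yield
# rewritten_query="onboarding documentation", time_filter="2024-06-01" and
# specified_source_types=["confluence"], which strings_to_document_sources would map to the
# matching DocumentSource members; an empty list is normalized to None so no source filter applies.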
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# retrieved_docs: list[InferenceSection] = []
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# for tool_response in search_tool.run(
# query=rewritten_query,
# document_sources=specified_source_types,
# time_filter=implied_time_filter,
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
# break
break
# # render the retrieved docs in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# retrieved_docs, is_internet=False
# ),
# ),
# writer,
# )
# render the retrieved docs in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
retrieved_docs, is_internet=False
),
),
writer,
)
# document_texts_list = []
document_texts_list = []
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# document_texts = "\n\n".join(document_texts_list)
document_texts = "\n\n".join(document_texts_list)
# logger.debug(
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # Build prompt
# Build prompt
# if use_agentic_search:
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=branch_query,
# base_question=base_question,
# document_text=document_texts,
# )
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# # Run LLM
# Run LLM
# # search_answer_json = None
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1500,
# )
# search_answer_json = None
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1500,
)
# logger.debug(
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # get cited documents
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning
# # answer_string = ""
# # claims = []
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# if citation_numbers and (
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
# ):
# raise ValueError("Citation numbers are out of range for retrieved docs.")
if citation_numbers and (
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
):
raise ValueError("Citation numbers are out of range for retrieved docs.")
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# }
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
# else:
# answer_string = ""
# claims = []
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
# }
# reasoning = ""
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=search_tool_info.llm_path,
# tool_id=search_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
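For orientation, a minimal standalone sketch of the time-filter validation performed in the act node above: only strict YYYY-MM-DD strings are parsed into a datetime, anything else is dropped. The helper name is illustrative and not part of the codebase.

import re
from datetime import datetime

def parse_implied_time_filter(implied_start_date: str | None) -> datetime | None:
    # Mirrors the check above: accept only strict YYYY-MM-DD strings.
    if not implied_start_date:
        return None
    if re.match(r"^\d{4}-\d{2}-\d{2}$", implied_start_date):
        return datetime.strptime(implied_start_date, "%Y-%m-%d")
    return None

# parse_implied_time_filter("2025-11-29") -> datetime(2025, 11, 29, 0, 0)
# parse_implied_time_filter("last week")  -> None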

View File

@@ -1,77 +1,77 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate the standard search results as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate the standard search results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
[update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(doc_list)
# retrieved_saved_search_docs = [
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
# for search_doc in search_docs
# ]
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
# for retrieved_saved_search_doc in retrieved_saved_search_docs:
# retrieved_saved_search_doc.is_internet = False
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
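The reducer above flattens the per-branch cited-document lists with a nested loop; the same operation as a comprehension, using plain integers as stand-ins for the cited document objects (illustrative only):

doc_lists = [[1, 2], [3], []]  # stand-in for the lists of cited documents per branch
doc_list = [doc for docs in doc_lists for doc in docs]
assert doc_list == [1, 2, 3]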

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
# basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
# basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the Basic Search Sub-Agent
"""
# def dr_basic_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the Basic Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", basic_search_branch)
# graph.add_node("branch", basic_search_branch)
graph.add_node("act", basic_search)
# graph.add_node("act", basic_search)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph
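A minimal sketch of how a builder like this is typically wired up, assuming only the standard LangGraph compile/invoke flow; the input construction is left commented out because the SubAgentInput fields are defined elsewhere (states.py) and are not reproduced here:

builder = dr_basic_search_graph_builder()
compiled_graph = builder.compile()
# result = compiled_graph.invoke(SubAgentInput(...))  # field values intentionally omitted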

View File

@@ -1,30 +1,30 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
current_step_nr=state.current_step_nr,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# current_step_nr=state.current_step_nr,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
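The router above fans each query in query_list (capped at MAX_DR_PARALLEL_SEARCH) out to a parallel "act" node via langgraph's Send. The enumeration and slicing it performs, shown framework-free with an assumed cap of 3 (the value is illustrative, not the real constant):

MAX_DR_PARALLEL_SEARCH = 3  # assumed value for this sketch only
query_list = ["q1", "q2", "q3", "q4", "q5"]
branches = list(enumerate(query_list[:MAX_DR_PARALLEL_SEARCH]))
# -> [(0, 'q1'), (1, 'q2'), (2, 'q3')]; remaining queries are not dispatched this iteration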

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,169 +1,164 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.name
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
# custom_tool_info = state.available_tools[state.tools_used[-1]]
# custom_tool_name = custom_tool_info.name
# custom_tool = cast(CustomTool, custom_tool_info.tool_object)
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=custom_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_LONG,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=custom_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[custom_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_LONG,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# response_summary: CustomToolCallSummary | None = None
# for tool_response in custom_tool.run(**tool_args):
# if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
# response_summary = cast(CustomToolCallSummary, tool_response.response)
# break
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
# if not response_summary:
# raise ValueError("Custom tool did not return a valid response summary")
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# # summarise tool result
# if not response_summary.response_type:
# raise ValueError("Response type is not returned.")
# summarise tool result
if not response_summary.response_type:
raise ValueError("Response type is not returned.")
# if response_summary.response_type == "json":
# tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
# elif response_summary.response_type in {"image", "csv"}:
# tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
# else:
# tool_result_str = str(response_summary.tool_result)
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
# tool_str = (
# f"Tool used: {custom_tool_name}\n"
# f"Description: {custom_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
# get file_ids:
file_ids = None
if response_summary.response_type in {"image", "csv"} and hasattr(
response_summary.tool_result, "file_ids"
):
file_ids = response_summary.tool_result.file_ids
# logger.debug(
# f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type=response_summary.response_type,
data=response_summary.tool_result,
file_ids=file_ids,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=custom_tool_name,
# tool_id=custom_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type=response_summary.response_type,
# data=response_summary.tool_result,
# file_ids=file_ids,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )
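A standalone sketch of the result-stringification branch in the act node above: JSON results are dumped, image/csv results are summarized by their file ids, and everything else is stringified. The function name and the minimal result object are illustrative, not part of the codebase:

import json
from dataclasses import dataclass, field

@dataclass
class FileResult:
    file_ids: list[str] = field(default_factory=list)

def stringify_tool_result(response_type: str, tool_result) -> str:
    if response_type == "json":
        return json.dumps(tool_result, ensure_ascii=False)
    if response_type in {"image", "csv"}:
        return f"{response_type} files: {tool_result.file_ids}"
    return str(tool_result)

# stringify_tool_result("json", {"ok": True})        -> '{"ok": true}'
# stringify_tool_result("image", FileResult(["f1"])) -> "image files: ['f1']"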

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate generic tool call results as part of the DR process.
"""
# def custom_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate generic tool call results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=new_update.file_ids,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=new_update.file_ids,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
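The reducer keeps only the branch responses produced in the current iteration before streaming them to the frontend; schematically, with a minimal stand-in for the response objects (illustrative only):

from dataclasses import dataclass

@dataclass
class Response:
    iteration_nr: int

branch_updates = [Response(1), Response(2), Response(2)]
current_iteration = 2
new_updates = [u for u in branch_updates if u.iteration_nr == current_iteration]
assert len(new_updates) == 2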

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
# custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
# custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
# custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Custom Tool Sub-Agent
"""
# def dr_custom_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Custom Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", custom_tool_branch)
# graph.add_node("branch", custom_tool_branch)
graph.add_node("act", custom_tool_act)
# graph.add_node("act", custom_tool_act)
graph.add_node("reducer", custom_tool_reducer)
# graph.add_node("reducer", custom_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="generic_internal_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="generic_internal_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,149 +1,147 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
generic_internal_tool_name = generic_internal_tool_info.llm_path
generic_internal_tool = generic_internal_tool_info.tool_object
# generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
# generic_internal_tool_name = generic_internal_tool_info.llm_path
# generic_internal_tool = generic_internal_tool_info.tool_object
if generic_internal_tool is None:
raise ValueError("generic_internal_tool is not set")
# if generic_internal_tool is None:
# raise ValueError("generic_internal_tool is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=generic_internal_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[generic_internal_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_SHORT,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=generic_internal_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[generic_internal_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# tool_responses = list(generic_internal_tool.run(**tool_args))
# final_data = generic_internal_tool.final_result(*tool_responses)
# tool_result_str = json.dumps(final_data, ensure_ascii=False)
# run the tool
tool_responses = list(generic_internal_tool.run(**tool_args))
final_data = generic_internal_tool.final_result(*tool_responses)
tool_result_str = json.dumps(final_data, ensure_ascii=False)
# tool_str = (
# f"Tool used: {generic_internal_tool.display_name}\n"
# f"Description: {generic_internal_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {generic_internal_tool.display_name}\n"
f"Description: {generic_internal_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# if generic_internal_tool.display_name == "Okta Profile":
# tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
# else:
# tool_prompt = CUSTOM_TOOL_USE_PROMPT
if generic_internal_tool.display_name == "Okta Profile":
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
else:
tool_prompt = CUSTOM_TOOL_USE_PROMPT
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = tool_prompt.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
logger.debug(
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=generic_internal_tool.llm_name,
tool_id=generic_internal_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type="text", # TODO: convert all response types to enums
data=answer_string,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=generic_internal_tool.name,
# tool_id=generic_internal_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type="text", # TODO: convert all response types to enums
# data=answer_string,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )
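A sketch of the tool-call guard used above: the tool-calling LLM's reply is only trusted when it is an AIMessage carrying exactly one tool call; otherwise tool_args stays None and the fallback path runs. The message below is hand-constructed for illustration, and the tool name and arguments are invented:

from langchain_core.messages import AIMessage

msg = AIMessage(content="", tool_calls=[{"name": "okta_profile", "args": {"user": "x"}, "id": "call_1"}])
tool_args = None
if isinstance(msg, AIMessage) and len(msg.tool_calls) == 1:
    tool_args = msg.tool_calls[0]["args"]
# tool_args -> {'user': 'x'}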

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate generic tool call results as part of the DR process.
"""
# def generic_internal_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate generic tool call results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=[],
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
# generic_internal_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
# generic_internal_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
# generic_internal_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_generic_internal_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# def dr_generic_internal_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", generic_internal_tool_branch)
# graph.add_node("branch", generic_internal_tool_branch)
graph.add_node("act", generic_internal_tool_act)
# graph.add_node("act", generic_internal_tool_act)
graph.add_node("reducer", generic_internal_tool_reducer)
# graph.add_node("reducer", generic_internal_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,45 +1,45 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform image generation as part of the DR process.
"""
# def image_generation_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform image generation as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# tell frontend that we are starting the image generation tool
write_custom_event(
state.current_step_nr,
ImageGenerationToolStart(),
writer,
)
# # tell frontend that we are starting the image generation tool
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolStart(),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,189 +1,187 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.file_store.utils import build_frontend_file_url
# from onyx.file_store.utils import save_files
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_HEARTBEAT_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_RESPONSE_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationResponse,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationTool,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to generate images via the image generation tool as part of the DR process.
"""
# def image_generation(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to generate images via the image generation tool as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
state.assistant_system_prompt
state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# state.assistant_system_prompt
# state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.prompt_builder.raw_user_query
graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
# image_tool_info = state.available_tools[state.tools_used[-1]]
# image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
# image_prompt = branch_query
# requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
# try:
# parsed_query = json.loads(branch_query)
# except json.JSONDecodeError:
# parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
# if isinstance(parsed_query, dict):
# prompt_from_llm = parsed_query.get("prompt")
# if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
# image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
# raw_shape = parsed_query.get("shape")
# if isinstance(raw_shape, str):
# try:
# requested_shape = ImageShape(raw_shape)
# except ValueError:
# logger.warning(
# "Received unsupported image shape '%s' from LLM. Falling back to square.",
# raw_shape,
# )
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
# # Generate images using the image generation tool
# image_generation_responses: list[ImageGenerationResponse] = []
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
# if requested_shape is not None:
# tool_iterator = image_tool.run(
# prompt=image_prompt,
# shape=requested_shape.value,
# )
# else:
# tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
state.current_step_nr,
ImageGenerationToolHeartbeat(),
writer,
)
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
response = cast(list[ImageGenerationResponse], tool_response.response)
image_generation_responses = response
break
# for tool_response in tool_iterator:
# if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# # Stream heartbeat to frontend
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolHeartbeat(),
# writer,
# )
# elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
# response = cast(list[ImageGenerationResponse], tool_response.response)
# image_generation_responses = response
# break
# save images to file store
file_ids = save_files(
urls=[],
base64_files=[img.image_data for img in image_generation_responses],
)
# # save images to file store
# file_ids = save_files(
# urls=[],
# base64_files=[img.image_data for img in image_generation_responses],
# )
final_generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
# final_generated_images = [
# GeneratedImage(
# file_id=file_id,
# url=build_frontend_file_url(file_id),
# revised_prompt=img.revised_prompt,
# shape=(requested_shape or ImageShape.SQUARE).value,
# )
# for file_id, img in zip(file_ids, image_generation_responses)
# ]
logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Create answer string describing the generated images
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
# # Create answer string describing the generated images
# if final_generated_images:
# image_descriptions = []
# for i, img in enumerate(final_generated_images, 1):
# if img.shape and img.shape != ImageShape.SQUARE.value:
# image_descriptions.append(
# f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
# )
# else:
# image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
# answer_string = (
# f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
# + "\n".join(image_descriptions)
# )
# if requested_shape:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
# )
# else:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) based on the user's request."
# )
# else:
# answer_string = f"Failed to generate images for request: {image_prompt}"
# reasoning = "Image generation tool did not return any results."
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=image_tool_info.llm_path,
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning=reasoning,
generated_images=final_generated_images,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="generating",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=image_tool_info.llm_path,
# tool_id=image_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning=reasoning,
# generated_images=final_generated_images,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="generating",
# node_start_time=node_start_time,
# )
# ],
# )
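For reference, a minimal, self-contained sketch of the payload handling above: the branch question is treated either as a plain-text prompt or as a JSON object with optional "prompt" and "shape" keys, and an unsupported shape falls back to the default. The Shape enum below is an illustrative stand-in only; the actual ImageShape members live in onyx.tools.tool_implementations.images.image_generation_tool.

import json
from enum import Enum

class Shape(str, Enum):  # illustrative stand-in for ImageShape
    SQUARE = "square"
    LANDSCAPE = "landscape"

def parse_branch_query(branch_query: str) -> tuple[str, Shape | None]:
    prompt, shape = branch_query, None
    try:
        parsed = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        candidate = parsed.get("prompt")
        if isinstance(candidate, str) and candidate.strip():
            prompt = candidate.strip()
        raw_shape = parsed.get("shape")
        if isinstance(raw_shape, str):
            try:
                shape = Shape(raw_shape)
            except ValueError:
                shape = None  # unsupported shape -> default handling downstream

    return prompt, shape

# parse_branch_query('{"prompt": "a red fox", "shape": "landscape"}')
# -> ("a red fox", Shape.LANDSCAPE)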


@@ -1,71 +1,71 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate image generation results as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate image generation results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
generated_images: list[GeneratedImage] = []
for update in new_updates:
if update.generated_images:
generated_images.extend(update.generated_images)
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
# generated_images: list[GeneratedImage] = []
# for update in new_updates:
# if update.generated_images:
# generated_images.extend(update.generated_images)
# Write the results to the stream
write_custom_event(
current_step_nr,
ImageGenerationToolDelta(
images=generated_images,
),
writer,
)
# # Write the results to the stream
# write_custom_event(
# current_step_nr,
# ImageGenerationFinal(
# images=generated_images,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )


@@ -1,29 +1,29 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]


@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
# image_generation_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
# image_generation,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_image_generation_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Image Generation Sub-Agent
"""
# def dr_image_generation_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Image Generation Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", image_generation_branch)
# graph.add_node("branch", image_generation_branch)
graph.add_node("act", image_generation)
# graph.add_node("act", image_generation)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph
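The router and builder above follow LangGraph's map-reduce pattern: the conditional edge returns a list of Send objects, each carrying its own BranchInput, so "act" runs once per branch question in parallel and the reducer consolidates the merged state. A minimal, self-contained sketch of that pattern, using a plain TypedDict state instead of the repo's Pydantic models:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send

class FanOutState(TypedDict):
    queries: list[str]
    results: Annotated[list[str], add]  # merged across parallel branches

def branch(state: FanOutState) -> dict:
    return {}  # branching itself only routes; no state change

def router(state: FanOutState) -> list[Send]:
    # one Send per query -> "act" runs once per payload, in parallel
    return [Send("act", {"queries": [q], "results": []}) for q in state["queries"]]

def act(state: FanOutState) -> dict:
    return {"results": [f"answered: {state['queries'][0]}"]}

def reducer(state: FanOutState) -> dict:
    return {}

graph = StateGraph(FanOutState)
graph.add_node("branch", branch)
graph.add_node("act", act)
graph.add_node("reducer", reducer)
graph.add_edge(START, "branch")
graph.add_conditional_edges("branch", router)
graph.add_edge("act", "reducer")
graph.add_edge("reducer", END)

compiled = graph.compile()
# compiled.invoke({"queries": ["q1", "q2"], "results": []})
# -> "results" holds one entry per query, concatenated via the `add` reducer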


@@ -1,13 +1,13 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# class GeneratedImage(BaseModel):
# file_id: str
# url: str
# revised_prompt: str
# shape: str | None = None
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
images: list[GeneratedImage]
# # Needed for PydanticType
# class GeneratedImageFullResult(BaseModel):
# images: list[GeneratedImage]


@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )


@@ -1,97 +1,97 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search(
# state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> BranchUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
state.current_step_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# state.current_step_nr
# parallelization_nr = state.parallelization_nr
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
# search_query = state.branch_question
# if not search_query:
# raise ValueError("search_query is not set")
logger.debug(
f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
kg_tool_info = state.available_tools[state.tools_used[-1]]
# kg_tool_info = state.available_tools[state.tools_used[-1]]
kb_graph = kb_graph_builder().compile()
# kb_graph = kb_graph_builder().compile()
kb_results = kb_graph.invoke(
input=KbMainInput(question=search_query, individual_flow=False),
config=config,
)
# kb_results = kb_graph.invoke(
# input=KbMainInput(question=search_query, individual_flow=False),
# config=config,
# )
# get cited documents
answer_string = kb_results.get("final_answer") or "No answer provided"
claims: list[str] = []
retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
# # get cited documents
# answer_string = kb_results.get("final_answer") or "No answer provided"
# claims: list[str] = []
# retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# if citation is empty, the answer must have come from the KG rather than a doc
# in that case, simply cite the docs returned by the KG
if not citation_numbers:
citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
# # if citation is empty, the answer must have come from the KG rather than a doc
# # in that case, simply cite the docs returned by the KG
# if not citation_numbers:
# citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
if citation_number <= len(retrieved_docs)
}
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# if citation_number <= len(retrieved_docs)
# }
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=kg_tool_info.llm_path,
tool_id=kg_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=None,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=kg_tool_info.llm_path,
# tool_id=kg_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=search_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=None,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
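The citation fallback above can be summarized as: if the extracted answer carries no inline citation numbers, every retrieved document is cited; otherwise the 1-based citation numbers are mapped onto the retrieved list and out-of-range references are dropped. A simplified, self-contained sketch (strings stand in for InferenceSection objects):

def map_citations(
    citation_numbers: list[int], retrieved_docs: list[str]
) -> dict[int, str]:
    # no citations extracted -> the answer came from the KG itself, so cite everything
    if not citation_numbers:
        citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
    # map 1-based citation numbers onto the retrieved documents
    return {
        n: retrieved_docs[n - 1]
        for n in citation_numbers
        if 0 < n <= len(retrieved_docs)  # lower-bound guard added for this sketch
    }

# map_citations([2], ["doc_a", "doc_b", "doc_c"]) -> {2: "doc_b"}
# map_citations([], ["doc_a", "doc_b"]) -> {1: "doc_a", 2: "doc_b"}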


@@ -1,121 +1,121 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningStart
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
# _MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
def kg_search_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to consolidate KG search results as part of the DR process.
"""
# def kg_search_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to consolidate KG search results as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# queries = [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
kg_answer = (
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
if len(queries) == 1
else None
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
# kg_answer = (
# "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
# if len(queries) == 1
# else None
# )
if len(retrieved_search_docs) > 0:
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# if len(retrieved_search_docs) > 0:
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=queries,
# documents=retrieved_search_docs,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
if kg_answer is not None:
# if kg_answer is not None:
kg_display_answer = (
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
else kg_answer
)
# kg_display_answer = (
# f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
# if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
# else kg_answer
# )
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(reasoning=kg_display_answer),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# ReasoningStart(),
# writer,
# )
# write_custom_event(
# current_step_nr,
# ReasoningDelta(reasoning=kg_display_answer),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )


@@ -1,27 +1,27 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]


@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", kg_search_branch)
# graph.add_node("branch", kg_search_branch)
graph.add_node("act", kg_search)
# graph.add_node("act", kg_search)
graph.add_node("reducer", kg_search_reducer)
# graph.add_node("reducer", kg_search_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph


@@ -1,46 +1,46 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool
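One convention worth noting in the state classes above: LangGraph accepts Pydantic models as state schemas, and fields annotated with Annotated[list[...], add] act as reducers, so partial updates returned from parallel branches are concatenated rather than overwritten. A tiny sketch of that merge behaviour, under that assumption:

from operator import add
from typing import Annotated

from pydantic import BaseModel

class Update(BaseModel):
    answers: Annotated[list[str], add] = []

# Conceptually, LangGraph applies the reducer (here operator.add) when it
# merges the outputs of two parallel branches into the shared state:
merged = add(Update(answers=["branch 1"]).answers, Update(answers=["branch 2"]).answers)
assert merged == ["branch 1", "branch 2"]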


@@ -1,70 +1,71 @@
from collections.abc import Sequence
# from collections.abc import Sequence
from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
# from exa_py import Exa
# from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.exa = Exa(api_key=api_key)
# class ExaClient(WebSearchProvider):
# def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
# self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
),
num_results=10,
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# response = self.exa.search_and_contents(
# query,
# type="auto",
# highlights=HighlightsContentsOptions(
# num_sentences=2,
# highlights_per_url=1,
# ),
# num_results=10,
# )
return [
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebSearchResult(
# title=result.title or "",
# link=result.url,
# snippet=result.highlights[0] if result.highlights else "",
# author=result.author,
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: Sequence[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=list(urls),
text=True,
livecrawl="preferred",
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# response = self.exa.get_contents(
# urls=list(urls),
# text=True,
# livecrawl="preferred",
# )
return [
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebContent(
# title=result.title or "",
# link=result.url,
# full_content=result.text or "",
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
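A usage sketch for the provider interface above, assuming the ExaClient defined in this file is in scope and a valid Exa API key is supplied; the method names (search, contents) and result fields (link, title, full_content) are the ones shown above.

client = ExaClient(api_key="...")  # a real Exa API key is required here
results = client.search("onyx deep research agent")
top_links = [r.link for r in results[:3]]
for page in client.contents(top_links):
    # full_content is the live-crawled page text; empty if nothing was returned
    print(page.title, len(page.full_content))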


@@ -1,159 +1,148 @@
import json
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
# import json
# from collections.abc import Sequence
# from concurrent.futures import ThreadPoolExecutor
import requests
# import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import SERPER_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
# SERPER_SEARCH_URL = "https://google.serper.dev/search"
# SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
# class SerperClient(WebSearchProvider):
# def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
# self.headers = {
# "X-API-KEY": api_key,
# "Content-Type": "application/json",
# }
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
payload = {
"q": query,
}
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# payload = {
# "q": query,
# }
response = requests.post(
SERPER_SEARCH_URL,
headers=self.headers,
data=json.dumps(payload),
)
# response = requests.post(
# SERPER_SEARCH_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper search failed. Check credentials or quota."
) from None
# response.raise_for_status()
results = response.json()
organic_results = results["organic"]
# results = response.json()
# organic_results = results["organic"]
return [
WebSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
author=None,
published_date=None,
)
for result in organic_results
]
# return [
# WebSearchResult(
# title=result["title"],
# link=result["link"],
# snippet=result["snippet"],
# author=None,
# published_date=None,
# )
# for result in organic_results
# ]
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# if not urls:
# return []
# Serper can respond with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccessful scrape.
def safe_get_webpage_content(url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# # Serper can respond with 500s regularly. We want to retry,
# # but in the event of failure, return an unsuccessful scrape.
# def safe_get_webpage_content(url: str) -> WebContent:
# try:
# return self._get_webpage_content(url)
# except Exception:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
return list(e.map(safe_get_webpage_content, urls))
# with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
# return list(e.map(safe_get_webpage_content, urls))
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
}
# @retry_builder(tries=3, delay=1, backoff=2)
# def _get_webpage_content(self, url: str) -> WebContent:
# payload = {
# "url": url,
# }
response = requests.post(
SERPER_CONTENTS_URL,
headers=self.headers,
data=json.dumps(payload),
)
# response = requests.post(
# SERPER_CONTENTS_URL,
# headers=self.headers,
# data=json.dumps(payload),
# )
# 400 returned when serper cannot scrape
if response.status_code == 400:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
# # 400 returned when serper cannot scrape
# if response.status_code == 400:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper content fetch failed. Check credentials."
) from None
# response.raise_for_status()
response_json = response.json()
# response_json = response.json()
# Response only guarantees text
text = response_json["text"]
# # Response only guarantees text
# text = response_json["text"]
# metadata & jsonld is not guaranteed to be present
metadata = response_json.get("metadata", {})
jsonld = response_json.get("jsonld", {})
# # metadata & jsonld is not guaranteed to be present
# metadata = response_json.get("metadata", {})
# jsonld = response_json.get("jsonld", {})
title = extract_title_from_metadata(metadata)
# title = extract_title_from_metadata(metadata)
# Serper does not provide a reliable mechanism to extract the url
response_url = url
published_date_str = extract_published_date_from_jsonld(jsonld)
published_date = None
# # Serper does not provide a reliable mechanism to extract the url
# response_url = url
# published_date_str = extract_published_date_from_jsonld(jsonld)
# published_date = None
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
# if published_date_str:
# try:
# published_date = time_str_to_utc(published_date_str)
# except Exception:
# published_date = None
return WebContent(
title=title or "",
link=response_url,
full_content=text or "",
published_date=published_date,
)
# return WebContent(
# title=title or "",
# link=response_url,
# full_content=text or "",
# published_date=published_date,
# )
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
# def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
# keys = ["title", "og:title"]
# return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
# def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
# keys = ["dateModified"]
# return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None
# def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
# for key in keys:
# if key in data:
# return data[key]
# return None


@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a web search as part of the DR process.
"""
# def is_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a web search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=True,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=True,
# ),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )


@@ -1,137 +1,128 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from langsmith import traceable
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from langsmith import traceable
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import WebSearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def web_search(
state: InternetSearchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchUpdate:
"""
LangGraph node to perform internet search and decide which URLs to fetch.
"""
# def web_search(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchUpdate:
# """
# LangGraph node to perform internet search and decide which URLs to fetch.
# """
node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# current_step_nr = state.current_step_nr
if not current_step_nr:
raise ValueError("Current step number is not set. This should not happen.")
# if not current_step_nr:
# raise ValueError("Current step number is not set. This should not happen.")
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
if not state.available_tools:
raise ValueError("available_tools is not set")
search_query = state.branch_question
# if not state.available_tools:
# raise ValueError("available_tools is not set")
# search_query = state.branch_question
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[search_query],
documents=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[search_query],
# documents=[],
# ),
# writer,
# )
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
# if graph_config.inputs.persona is None:
# raise ValueError("persona is not set")
provider = get_default_provider()
if not provider:
raise ValueError("No internet search provider found")
# provider = get_default_provider()
# if not provider:
# raise ValueError("No internet search provider found")
# Log which provider type is being used
provider_type = type(provider).__name__
logger.info(
f"Performing web search with {provider_type} for query: '{search_query}'"
)
# @traceable(name="Search Provider API Call")
# def _search(search_query: str) -> list[WebSearchResult]:
# search_results: list[WebSearchResult] = []
# try:
# search_results = list(provider.search(search_query))
# except Exception as e:
# logger.error(f"Error performing search: {e}")
# return search_results
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = list(provider.search(search_query))
logger.info(
f"Search returned {len(search_results)} results using {provider_type}"
)
except Exception as e:
logger.error(f"Error performing search with {provider_type}: {e}")
return search_results
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"
+ (f" Author: {result.author}\n" if result.author else "")
+ (
f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
if result.published_date
else ""
)
+ (f" Snippet: {result.snippet}\n" if result.snippet else "")
for i, result in enumerate(search_results)
]
)
agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
search_query=search_query,
base_question=base_question,
search_results_text=search_results_text,
)
agent_decision = invoke_llm_json(
llm=graph_config.tooling.fast_llm,
prompt=create_question_prompt(
assistant_system_prompt,
agent_decision_prompt + (assistant_task_prompt or ""),
),
schema=WebSearchAnswer,
timeout_override=TF_DR_TIMEOUT_SHORT,
)
results_to_open = [
(search_query, search_results[i])
for i in agent_decision.urls_to_open_indices
if i < len(search_results) and i >= 0
]
return InternetSearchUpdate(
results_to_open=results_to_open,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# search_results: list[WebSearchResult] = _search(search_query)
# search_results_text = "\n\n".join(
# [
# f"{i}. {result.title}\n URL: {result.link}\n"
# + (f" Author: {result.author}\n" if result.author else "")
# + (
# f" Date: {result.published_date.strftime('%Y-%m-%d')}\n"
# if result.published_date
# else ""
# )
# + (f" Snippet: {result.snippet}\n" if result.snippet else "")
# for i, result in enumerate(search_results)
# ]
# )
# agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
# search_query=search_query,
# base_question=base_question,
# search_results_text=search_results_text,
# )
# agent_decision = invoke_llm_json(
# llm=graph_config.tooling.fast_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# agent_decision_prompt + (assistant_task_prompt or ""),
# ),
# schema=WebSearchAnswer,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# results_to_open = [
# (search_query, search_results[i])
# for i in agent_decision.urls_to_open_indices
# if i < len(search_results) and i >= 0
# ]
# return InternetSearchUpdate(
# results_to_open=results_to_open,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
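
For context on the step above: web_search numbers the provider results, asks the fast LLM (via invoke_llm_json with schema=WebSearchAnswer) which ones are worth opening, and then keeps only in-range indices. A minimal sketch of that pattern, using a hypothetical WebSearchAnswer stand-in (the real model is imported from onyx and may carry more fields):

from pydantic import BaseModel


# Hypothetical stand-in for onyx's WebSearchAnswer; only the field actually
# used above (urls_to_open_indices) is modeled here.
class WebSearchAnswer(BaseModel):
    urls_to_open_indices: list[int]


def select_results(results: list[str], answer: WebSearchAnswer) -> list[str]:
    # Keep only indices that are in range, mirroring the guard in web_search.
    return [results[i] for i in answer.urls_to_open_indices if 0 <= i < len(results)]


# An out-of-range index (7) returned by the LLM is dropped rather than raising.
assert select_results(
    ["a", "b", "c", "d"], WebSearchAnswer(urls_to_open_indices=[0, 3, 7])
) == ["a", "d"]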


@@ -1,52 +1,52 @@
from collections import defaultdict
# from collections import defaultdict
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# dummy_inference_section_from_internet_search_result,
# )
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
def dedup_urls(
state: InternetSearchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, WebSearchResult] = {}
for query, result in state.results_to_open:
branch_questions_to_urls[query].append(result.link)
if result.link not in unique_results_by_link:
unique_results_by_link[result.link] = result
# def dedup_urls(
# state: InternetSearchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
# unique_results_by_link: dict[str, WebSearchResult] = {}
# for query, result in state.results_to_open:
# branch_questions_to_urls[query].append(result.link)
# if result.link not in unique_results_by_link:
# unique_results_by_link[result.link] = result
unique_results = list(unique_results_by_link.values())
dummy_docs_inference_sections = [
dummy_inference_section_from_internet_search_result(doc)
for doc in unique_results
]
# unique_results = list(unique_results_by_link.values())
# dummy_docs_inference_sections = [
# dummy_inference_section_from_internet_search_result(doc)
# for doc in unique_results
# ]
write_custom_event(
state.current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
dummy_docs_inference_sections, is_internet=True
),
),
writer,
)
# write_custom_event(
# state.current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# dummy_docs_inference_sections, is_internet=True
# ),
# ),
# writer,
# )
return InternetSearchInput(
results_to_open=[],
branch_questions_to_urls=branch_questions_to_urls,
)
# return InternetSearchInput(
# results_to_open=[],
# branch_questions_to_urls=branch_questions_to_urls,
# )
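
dedup_urls above fetches each link at most once while still remembering which branch question surfaced it. The same first-wins bookkeeping in isolation, with plain strings standing in for WebSearchResult:

from collections import defaultdict

# (branch question, link) pairs as they might arrive from parallel searches.
results_to_open = [
    ("q1", "https://example.com/a"),
    ("q2", "https://example.com/a"),  # same link surfaced by a second question
    ("q2", "https://example.com/b"),
]

branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, str] = {}

for query, link in results_to_open:
    branch_questions_to_urls[query].append(link)
    if link not in unique_results_by_link:  # first occurrence wins
        unique_results_by_link[link] = link

# Two fetches instead of three, and every question still maps to its own URLs.
assert list(unique_results_by_link) == ["https://example.com/a", "https://example.com/b"]
assert branch_questions_to_urls["q2"] == ["https://example.com/a", "https://example.com/b"]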


@@ -1,69 +1,69 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
# get_default_provider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
# inference_section_from_internet_page_scrape,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def web_fetch(
state: FetchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
"""
LangGraph node to fetch content from URLs and process the results.
"""
# def web_fetch(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> FetchUpdate:
# """
# LangGraph node to fetch content from URLs and process the results.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
# if graph_config.inputs.persona is None:
# raise ValueError("persona is not set")
provider = get_default_content_provider()
if provider is None:
raise ValueError("No web content provider found")
# provider = get_default_provider()
# if provider is None:
# raise ValueError("No web search provider found")
retrieved_docs: list[InferenceSection] = []
try:
retrieved_docs = [
dummy_inference_section_from_internet_content(result)
for result in provider.contents(state.urls_to_open)
]
except Exception as e:
logger.exception(e)
# retrieved_docs: list[InferenceSection] = []
# try:
# retrieved_docs = [
# inference_section_from_internet_page_scrape(result)
# for result in provider.contents(state.urls_to_open)
# ]
# except Exception as e:
# logger.error(f"Error fetching URLs: {e}")
if not retrieved_docs:
logger.warning("No content retrieved from URLs")
# if not retrieved_docs:
# logger.warning("No content retrieved from URLs")
return FetchUpdate(
raw_documents=retrieved_docs,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="fetching",
node_start_time=node_start_time,
)
],
)
# return FetchUpdate(
# raw_documents=retrieved_docs,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="fetching",
# node_start_time=node_start_time,
# )
# ],
# )
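
web_fetch above deliberately catches provider failures and returns whatever was retrieved, so one bad scrape does not fail the whole branch; an empty result only produces a warning. The same defensive shape reduced to a placeholder provider (PageProvider and fetch_all are illustrative names, not onyx APIs):

import logging
from typing import Protocol

logger = logging.getLogger(__name__)


class PageProvider(Protocol):
    def contents(self, urls: list[str]) -> list[str]: ...


def fetch_all(provider: PageProvider, urls: list[str]) -> list[str]:
    retrieved: list[str] = []
    try:
        # Any provider exception is logged and swallowed; callers get an empty
        # list instead of a crash.
        retrieved = list(provider.contents(urls))
    except Exception:
        logger.exception("Error fetching URLs")
    if not retrieved:
        logger.warning("No content retrieved from URLs")
    return retrieved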


@@ -1,19 +1,19 @@
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
def collect_raw_docs(
state: FetchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
raw_documents = state.raw_documents
# def collect_raw_docs(
# state: FetchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> InternetSearchInput:
# raw_documents = state.raw_documents
return InternetSearchInput(
raw_documents=raw_documents,
)
# return InternetSearchInput(
# raw_documents=raw_documents,
# )


@@ -1,133 +1,132 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.utils.logger import setup_logger
from onyx.utils.url import normalize_url
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.utils.logger import setup_logger
# from onyx.utils.url import normalize_url
logger = setup_logger()
# logger = setup_logger()
def is_summarize(
state: SummarizeInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform an internet search as part of the DR process.
"""
# def is_summarize(
# state: SummarizeInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform an internet search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
# build branch iterations from fetch inputs
# Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
url_to_raw_document: dict[str, InferenceSection] = {}
for raw_document in state.raw_documents:
normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
url_to_raw_document[normalized_url] = raw_document
# # build branch iterations from fetch inputs
# # Normalize URLs to handle mismatches from query parameters (e.g., ?activeTab=explore)
# url_to_raw_document: dict[str, InferenceSection] = {}
# for raw_document in state.raw_documents:
# normalized_url = normalize_url(raw_document.center_chunk.semantic_identifier)
# url_to_raw_document[normalized_url] = raw_document
# Normalize the URLs from branch_questions_to_urls as well
urls = [
normalize_url(url)
for url in state.branch_questions_to_urls[state.branch_question]
]
current_iteration = state.iteration_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
research_type = graph_config.behavior.research_type
if not state.available_tools:
raise ValueError("available_tools is not set")
is_tool_info = state.available_tools[state.tools_used[-1]]
# # Normalize the URLs from branch_questions_to_urls as well
# urls = [
# normalize_url(url)
# for url in state.branch_questions_to_urls[state.branch_question]
# ]
# current_iteration = state.iteration_nr
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# use_agentic_search = graph_config.behavior.use_agentic_search
# if not state.available_tools:
# raise ValueError("available_tools is not set")
# is_tool_info = state.available_tools[state.tools_used[-1]]
if research_type == ResearchType.DEEP:
cited_raw_documents = [url_to_raw_document[url] for url in urls]
document_texts = _create_document_texts(cited_raw_documents)
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=state.branch_question,
base_question=graph_config.inputs.prompt_builder.raw_user_query,
document_text=document_texts,
)
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_SHORT,
)
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning or ""
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
cited_documents = {
citation_number: cited_raw_documents[citation_number - 1]
for citation_number in citation_numbers
}
# if use_agentic_search:
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
# document_texts = _create_document_texts(cited_raw_documents)
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=state.branch_question,
# base_question=graph_config.inputs.prompt_builder.raw_user_query,
# document_text=document_texts,
# )
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning or ""
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# cited_documents = {
# citation_number: cited_raw_documents[citation_number - 1]
# for citation_number in citation_numbers
# }
else:
answer_string = ""
reasoning = ""
claims = []
cited_raw_documents = [url_to_raw_document[url] for url in urls]
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(cited_raw_documents)
}
# else:
# answer_string = ""
# reasoning = ""
# claims = []
# cited_raw_documents = [url_to_raw_document[url] for url in urls]
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(cited_raw_documents)
# }
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=is_tool_info.llm_path,
tool_id=is_tool_info.tool_id,
iteration_nr=current_iteration,
parallelization_nr=0,
question=state.branch_question,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="summarizing",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=is_tool_info.llm_path,
# tool_id=is_tool_info.tool_id,
# iteration_nr=current_iteration,
# parallelization_nr=0,
# question=state.branch_question,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="summarizing",
# node_start_time=node_start_time,
# )
# ],
# )
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
document_texts_list = []
for doc_num, retrieved_doc in enumerate(raw_documents):
if not isinstance(retrieved_doc, InferenceSection):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
return "\n\n".join(document_texts_list)
# def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
# document_texts_list = []
# for doc_num, retrieved_doc in enumerate(raw_documents):
# if not isinstance(retrieved_doc, InferenceSection):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
# return "\n\n".join(document_texts_list)
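
The lookup in is_summarize above only works if the URL chosen at search time and the URL stored on the fetched document compare equal, which is why both sides go through normalize_url first (the inline comment cites ?activeTab=explore as a real-world mismatch). A hypothetical normalize_url along those lines; the actual helper in onyx.utils.url may normalize differently:

from urllib.parse import urlsplit, urlunsplit


def normalize_url(url: str) -> str:
    # Hypothetical sketch: compare URLs by scheme, host, and path only, dropping
    # the query string, fragment, and any trailing slash.
    scheme, netloc, path, _query, _fragment = urlsplit(url)
    return urlunsplit((scheme, netloc, path.rstrip("/") or "/", "", ""))


assert normalize_url("https://x.dev/page?activeTab=explore") == normalize_url("https://x.dev/page/")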


@@ -1,56 +1,56 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform an internet search as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform an internet search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="internet_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="internet_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )


@@ -1,79 +1,79 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
# InternetSearchInput,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"search",
InternetSearchInput(
iteration_nr=state.iteration_nr,
current_step_nr=state.current_step_nr,
parallelization_nr=parallelization_nr,
query_list=[query],
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
results_to_open=[],
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "search",
# InternetSearchInput(
# iteration_nr=state.iteration_nr,
# current_step_nr=state.current_step_nr,
# parallelization_nr=parallelization_nr,
# query_list=[query],
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# results_to_open=[],
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
branch_questions_to_urls = state.branch_questions_to_urls
return [
Send(
"fetch",
FetchInput(
iteration_nr=state.iteration_nr,
urls_to_open=[url],
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
current_step_nr=state.current_step_nr,
branch_questions_to_urls=branch_questions_to_urls,
raw_documents=state.raw_documents,
),
)
for url in set(
url for urls in branch_questions_to_urls.values() for url in urls
)
]
# def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "fetch",
# FetchInput(
# iteration_nr=state.iteration_nr,
# urls_to_open=[url],
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# branch_questions_to_urls=branch_questions_to_urls,
# raw_documents=state.raw_documents,
# ),
# )
# for url in set(
# url for urls in branch_questions_to_urls.values() for url in urls
# )
# ]
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
branch_questions_to_urls = state.branch_questions_to_urls
return [
Send(
"summarize",
SummarizeInput(
iteration_nr=state.iteration_nr,
raw_documents=state.raw_documents,
branch_questions_to_urls=branch_questions_to_urls,
branch_question=branch_question,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
current_step_nr=state.current_step_nr,
),
)
for branch_question in branch_questions_to_urls.keys()
]
# def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
# branch_questions_to_urls = state.branch_questions_to_urls
# return [
# Send(
# "summarize",
# SummarizeInput(
# iteration_nr=state.iteration_nr,
# raw_documents=state.raw_documents,
# branch_questions_to_urls=branch_questions_to_urls,
# branch_question=branch_question,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# current_step_nr=state.current_step_nr,
# ),
# )
# for branch_question in branch_questions_to_urls.keys()
# ]
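
All three routers above return Send objects, which is how LangGraph fans a single step out into parallel invocations of the same node, each with its own payload. A self-contained toy version of that map-reduce pattern (the state and node names here are illustrative, unrelated to onyx's models):

import operator
from typing import Annotated

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from typing_extensions import TypedDict


class State(TypedDict):
    queries: list[str]
    results: Annotated[list[str], operator.add]  # reducer merges parallel writes


def fan_out(state: State) -> list[Send]:
    # One "search" invocation per query, mirroring branching_router above.
    return [Send("search", {"query": q}) for q in state["queries"]]


def search(state: dict) -> dict:
    return {"results": [f"hit for {state['query']}"]}


builder = StateGraph(State)
builder.add_node("search", search)
builder.add_conditional_edges(START, fan_out)
builder.add_edge("search", END)

print(builder.compile().invoke({"queries": ["alpha", "beta"], "results": []}))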


@@ -1,84 +1,84 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
is_branch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
web_search,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
dedup_urls,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
web_fetch,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
collect_raw_docs,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
is_summarize,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
fetch_router,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
summarize_router,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
# is_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
# web_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
# dedup_urls,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
# web_fetch,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
# collect_raw_docs,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
# is_summarize,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# fetch_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
# summarize_router,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_ws_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Internet Search Sub-Agent
"""
# def dr_ws_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Internet Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", is_branch)
# graph.add_node("branch", is_branch)
graph.add_node("search", web_search)
# graph.add_node("search", web_search)
graph.add_node("dedup_urls", dedup_urls)
# graph.add_node("dedup_urls", dedup_urls)
graph.add_node("fetch", web_fetch)
# graph.add_node("fetch", web_fetch)
graph.add_node("collect_raw_docs", collect_raw_docs)
# graph.add_node("collect_raw_docs", collect_raw_docs)
graph.add_node("summarize", is_summarize)
# graph.add_node("summarize", is_summarize)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="search", end_key="dedup_urls")
# graph.add_edge(start_key="search", end_key="dedup_urls")
graph.add_conditional_edges("dedup_urls", fetch_router)
# graph.add_conditional_edges("dedup_urls", fetch_router)
graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
# graph.add_edge(start_key="fetch", end_key="collect_raw_docs")
graph.add_conditional_edges("collect_raw_docs", summarize_router)
# graph.add_conditional_edges("collect_raw_docs", summarize_router)
graph.add_edge(start_key="summarize", end_key="reducer")
# graph.add_edge(start_key="summarize", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

Some files were not shown because too many files have changed in this diff.