Compare commits

..

4 Commits

Author SHA1 Message Date
Jessica Singh
d272ac252e mypy 2025-11-21 17:08:18 -08:00
Jessica Singh
4620fb1129 fix 2025-11-21 14:08:09 -08:00
Jessica Singh
7e334f0de1 use tokenizer 2025-11-21 14:04:07 -08:00
Jessica Singh
a381104d6f thread context for slack bot 2025-11-21 14:04:07 -08:00
558 changed files with 30655 additions and 45591 deletions

View File

@@ -17,7 +17,6 @@ self-hosted-runner:
- runner=16cpu-linux-x64
- ubuntu-slim # Currently in public preview
- volume=40gb
- volume=50gb
# Configuration variables in array of strings defined in your repository or
# organization. `null` means disabling configuration variables check.

View File

@@ -0,0 +1,135 @@
name: 'Build and Push Docker Image with Retry'
description: 'Attempts to build and push a Docker image, with a retry on failure'
inputs:
context:
description: 'Build context'
required: true
file:
description: 'Dockerfile location'
required: true
platforms:
description: 'Target platforms'
required: true
pull:
description: 'Always attempt to pull a newer version of the image'
required: false
default: 'true'
push:
description: 'Push the image to registry'
required: false
default: 'true'
load:
description: 'Load the image into Docker daemon'
required: false
default: 'true'
tags:
description: 'Image tags'
required: true
no-cache:
description: 'Read from cache'
required: false
default: 'false'
cache-from:
description: 'Cache sources'
required: false
cache-to:
description: 'Cache destinations'
required: false
outputs:
description: 'Output destinations'
required: false
provenance:
description: 'Generate provenance attestation'
required: false
default: 'false'
build-args:
description: 'Build arguments'
required: false
retry-wait-time:
description: 'Time to wait before attempt 2 in seconds'
required: false
default: '60'
retry-wait-time-2:
description: 'Time to wait before attempt 3 in seconds'
required: false
default: '120'
runs:
using: "composite"
steps:
- name: Build and push Docker image (Attempt 1 of 3)
id: buildx1
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
continue-on-error: true
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Wait before attempt 2
if: steps.buildx1.outcome != 'success'
run: |
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
sleep ${{ inputs.retry-wait-time }}
shell: bash
- name: Build and push Docker image (Attempt 2 of 3)
id: buildx2
if: steps.buildx1.outcome != 'success'
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Wait before attempt 3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
run: |
echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
sleep ${{ inputs.retry-wait-time-2 }}
shell: bash
- name: Build and push Docker image (Attempt 3 of 3)
id: buildx3
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ${{ inputs.context }}
file: ${{ inputs.file }}
platforms: ${{ inputs.platforms }}
pull: ${{ inputs.pull }}
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
outputs: ${{ inputs.outputs }}
provenance: ${{ inputs.provenance }}
build-args: ${{ inputs.build-args }}
- name: Report failure
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
run: |
echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
shell: bash

View File

@@ -0,0 +1,42 @@
name: "Prepare Build (OpenAPI generation)"
description: "Sets up Python with uv, installs deps, generates OpenAPI schema and Python client, uploads artifact"
inputs:
docker-username:
required: true
docker-password:
required: true
runs:
using: "composite"
steps:
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
- name: Generate OpenAPI schema
shell: bash
working-directory: backend
env:
PYTHONPATH: "."
run: |
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ inputs['docker-username'] }}
password: ${{ inputs['docker-password'] }}
- name: Generate OpenAPI Python client
shell: bash
run: |
docker run --rm \
-v "${{ github.workspace }}/backend/generated:/local" \
openapitools/openapi-generator-cli generate \
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"

View File

@@ -1,9 +1,5 @@
name: "Setup Python and Install Dependencies"
description: "Sets up Python with uv and installs deps"
inputs:
requirements:
description: "Newline-separated list of requirement files to install (relative to repo root)"
required: true
runs:
using: "composite"
steps:
@@ -13,26 +9,11 @@ runs:
# with:
# enable-cache: true
- name: Compute requirements hash
id: req-hash
shell: bash
env:
REQUIREMENTS: ${{ inputs.requirements }}
run: |
# Hash the contents of the specified requirement files
hash=""
while IFS= read -r req; do
if [ -n "$req" ] && [ -f "$req" ]; then
hash="$hash$(sha256sum "$req")"
fi
done <<< "$REQUIREMENTS"
echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
- name: Cache uv cache directory
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/uv
key: ${{ runner.os }}-uv-${{ steps.req-hash.outputs.hash }}
key: ${{ runner.os }}-uv-${{ hashFiles('backend/requirements/*.txt', 'backend/pyproject.toml') }}
restore-keys: |
${{ runner.os }}-uv-
@@ -43,30 +24,15 @@ runs:
- name: Create virtual environment
shell: bash
env:
VENV_DIR: ${{ runner.temp }}/venv
run: | # zizmor: ignore[github-env]
uv venv "$VENV_DIR"
# Validate path before adding to GITHUB_PATH to prevent code injection
if [ -d "$VENV_DIR/bin" ]; then
realpath "$VENV_DIR/bin" >> "$GITHUB_PATH"
else
echo "Error: $VENV_DIR/bin does not exist"
exit 1
fi
run: |
uv venv ${{ runner.temp }}/venv
echo "VENV_PATH=${{ runner.temp }}/venv" >> $GITHUB_ENV
echo "${{ runner.temp }}/venv/bin" >> $GITHUB_PATH
- name: Install Python dependencies with uv
shell: bash
env:
REQUIREMENTS: ${{ inputs.requirements }}
run: |
# Build the uv pip install command with each requirement file as array elements
cmd=("uv" "pip" "install")
while IFS= read -r req; do
# Skip empty lines
if [ -n "$req" ]; then
cmd+=("-r" "$req")
fi
done <<< "$REQUIREMENTS"
echo "Running: ${cmd[*]}"
"${cmd[@]}"
uv pip install \
-r backend/requirements/default.txt \
-r backend/requirements/dev.txt \
-r backend/requirements/model_server.txt

View File

@@ -21,27 +21,26 @@ runs:
shell: bash
env:
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
FAILED_JOBS: ${{ inputs.failed-jobs }}
TITLE: ${{ inputs.title }}
REF_NAME: ${{ inputs.ref-name }}
REPO: ${{ github.repository }}
WORKFLOW: ${{ github.workflow }}
RUN_NUMBER: ${{ github.run_number }}
RUN_ID: ${{ github.run_id }}
SERVER_URL: ${{ github.server_url }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
if [ -z "$SLACK_WEBHOOK_URL" ]; then
echo "webhook-url input or SLACK_WEBHOOK_URL env var is not set, skipping notification"
exit 0
fi
# Build workflow URL
# Get inputs with defaults
FAILED_JOBS="${{ inputs.failed-jobs }}"
TITLE="${{ inputs.title }}"
REF_NAME="${{ inputs.ref-name }}"
REPO="${{ github.repository }}"
WORKFLOW="${{ github.workflow }}"
RUN_NUMBER="${{ github.run_number }}"
RUN_ID="${{ github.run_id }}"
SERVER_URL="${{ github.server_url }}"
WORKFLOW_URL="${SERVER_URL}/${REPO}/actions/runs/${RUN_ID}"
# Use ref_name from input or fall back to github.ref_name
if [ -z "$REF_NAME" ]; then
REF_NAME="$GITHUB_REF_NAME"
REF_NAME="${{ github.ref_name }}"
fi
# Escape JSON special characters

View File

@@ -4,8 +4,6 @@ updates:
directory: "/"
schedule:
interval: "weekly"
cooldown:
default-days: 4
open-pull-requests-limit: 3
assignees:
- "jmelahman"
@@ -15,8 +13,6 @@ updates:
directory: "/backend"
schedule:
interval: "weekly"
cooldown:
default-days: 4
open-pull-requests-limit: 3
assignees:
- "jmelahman"

View File

@@ -16,11 +16,10 @@ permissions:
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

View File

@@ -18,7 +18,6 @@ jobs:
determine-builds:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
outputs:
build-web: ${{ steps.check.outputs.build-web }}
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
@@ -91,7 +90,6 @@ jobs:
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -100,7 +98,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -149,7 +147,6 @@ jobs:
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -158,7 +155,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -209,7 +206,6 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
@@ -257,7 +253,6 @@ jobs:
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -266,7 +261,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -323,7 +318,6 @@ jobs:
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -332,7 +326,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -391,7 +385,6 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
@@ -436,7 +429,6 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -445,7 +437,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -493,7 +485,6 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -502,7 +493,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -552,7 +543,6 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -601,7 +591,6 @@ jobs:
- run-id=${{ github.run_id }}-model-server-amd64
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -610,7 +599,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -665,7 +654,6 @@ jobs:
- run-id=${{ github.run_id }}-model-server-arm64
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -674,7 +662,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -730,7 +718,6 @@ jobs:
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
@@ -780,7 +767,6 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
@@ -820,7 +806,6 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
@@ -860,14 +845,13 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -907,7 +891,6 @@ jobs:
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
@@ -954,10 +937,9 @@ jobs:
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

View File

@@ -18,7 +18,6 @@ jobs:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

View File

@@ -18,7 +18,6 @@ jobs:
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

View File

@@ -12,10 +12,9 @@ jobs:
permissions:
contents: write
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false

View File

@@ -11,9 +11,8 @@ permissions:
jobs:
stale:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # ratchet:actions/stale@v9
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -20,7 +20,6 @@ jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
@@ -28,7 +27,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -90,7 +89,6 @@ jobs:
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx

View File

@@ -12,14 +12,12 @@ permissions:
contents: read
env:
# AWS credentials for S3-specific test
S3_AWS_ACCESS_KEY_ID_FOR_TEST: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY_FOR_TEST: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
S3_AWS_ACCESS_KEY_ID: "minioadmin"
S3_AWS_SECRET_ACCESS_KEY: "minioadmin"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
@@ -33,20 +31,15 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
# Code Interpreter
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -65,7 +58,6 @@ jobs:
- runner=2cpu-linux-arm64
- ${{ format('run-id={0}-external-dependency-unit-tests-job-{1}', github.run_id, strategy['job-index']) }}
- extras=s3-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
@@ -79,17 +71,12 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/ee.txt
- name: Setup Playwright
uses: ./.github/actions/setup-playwright
@@ -103,24 +90,10 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
CODE_INTERPRETER_BETA_ENABLED=true
EOF
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
up -d \
minio \
relational_db \
cache \
index \
code-interpreter
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index
- name: Run migrations
run: |
@@ -140,30 +113,3 @@ jobs:
-xv \
--ff \
backend/tests/external_dependency_unit/${TEST_DIR}
- name: Collect Docker logs on failure
if: failure()
run: |
mkdir -p docker-logs
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
# Collect logs from each container
for container in $containers; do
container_name=$(docker inspect --format='{{.Name}}' $container | sed 's/^\///')
echo "Collecting logs from $container_name..."
docker logs $container > ../../docker-logs/${container_name}.log 2>&1
done
cd ../..
echo "Docker logs collected in docker-logs directory"
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/
retention-days: 7

View File

@@ -16,12 +16,11 @@ jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
timeout-minutes: 45
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false

View File

@@ -35,12 +35,11 @@ jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -67,11 +66,10 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -104,11 +102,10 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -139,11 +136,10 @@ jobs:
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -185,7 +181,6 @@ jobs:
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
@@ -195,7 +190,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -226,7 +221,6 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
@@ -235,56 +229,43 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
docker logs -f onyx-api_server-1 &
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
sleep 5
done
}
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
wait_for_service "http://localhost:8080/health" "API server"
test_dir="${{ matrix.test-dir.path }}"
if [ "$test_dir" = "tests/mcp" ]; then
wait_for_service "http://localhost:8090/health" "MCP server"
else
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
fi
echo "Finished waiting for services."
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
@@ -313,8 +294,6 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
@@ -367,12 +346,11 @@ jobs:
build-integration-image,
]
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -396,7 +374,6 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
DEV_MODE=true \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
index \
@@ -405,7 +382,6 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker_multi_tenant
@@ -454,8 +430,6 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
@@ -497,7 +471,6 @@ jobs:
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:

View File

@@ -12,10 +12,9 @@ jobs:
jest-tests:
name: Jest Tests
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

View File

@@ -16,7 +16,6 @@ permissions:
jobs:
validate_pr_title:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR title for Conventional Commits
env:

View File

@@ -13,7 +13,6 @@ permissions:
jobs:
linear-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for Linear link or override
env:

View File

@@ -14,7 +14,6 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -32,12 +31,11 @@ jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -63,11 +61,10 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -98,11 +95,10 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -132,11 +128,10 @@ jobs:
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -178,7 +173,6 @@ jobs:
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
@@ -188,7 +182,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -217,7 +211,6 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
@@ -226,56 +219,43 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
docker logs -f onyx-api_server-1 &
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
sleep 5
done
}
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
wait_for_service "http://localhost:8080/health" "API server"
test_dir="${{ matrix.test-dir.path }}"
if [ "$test_dir" = "tests/mcp" ]; then
wait_for_service "http://localhost:8090/health" "MCP server"
else
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
fi
echo "Finished waiting for services."
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
@@ -305,10 +285,7 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
@@ -354,7 +331,6 @@ jobs:
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:

View File

@@ -27,13 +27,6 @@ env:
MCP_OAUTH_USERNAME: ${{ vars.MCP_OAUTH_USERNAME }}
MCP_OAUTH_PASSWORD: ${{ secrets.MCP_OAUTH_PASSWORD }}
# for MCP API Key tests
MCP_API_KEY: test-api-key-12345
MCP_API_KEY_TEST_PORT: 8005
MCP_API_KEY_TEST_URL: http://host.docker.internal:8005/mcp
MCP_API_KEY_SERVER_HOST: 0.0.0.0
MCP_API_KEY_SERVER_PUBLIC_HOST: host.docker.internal
MOCK_LLM_RESPONSE: true
MCP_TEST_SERVER_PORT: 8004
MCP_TEST_SERVER_URL: http://host.docker.internal:8004/mcp
@@ -47,12 +40,11 @@ env:
jobs:
build-web-image:
runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -83,12 +75,11 @@ jobs:
build-backend-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -120,12 +111,11 @@ jobs:
build-model-server-image:
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
@@ -157,13 +147,7 @@ jobs:
playwright-tests:
needs: [build-web-image, build-backend-image, build-model-server-image]
name: Playwright Tests (${{ matrix.project }})
runs-on:
- runs-on
- runner=8cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}"
- "extras=ecr-cache"
- volume=50gb
timeout-minutes: 45
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}", "extras=ecr-cache"]
strategy:
fail-fast: false
matrix:
@@ -172,7 +156,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
@@ -234,7 +218,7 @@ jobs:
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml -f docker-compose.mcp-api-key-test.yml up -d
docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml up -d
id: start_docker
- name: Wait for service to be ready
@@ -294,54 +278,6 @@ jobs:
sleep 3
done
- name: Wait for MCP API Key mock server
run: |
echo "Waiting for MCP API Key mock server on port ${MCP_API_KEY_TEST_PORT:-8005}..."
start_time=$(date +%s)
timeout=120
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. MCP API Key mock server did not become ready in ${timeout}s."
exit 1
fi
if curl -sf "http://localhost:${MCP_API_KEY_TEST_PORT:-8005}/healthz" > /dev/null; then
echo "MCP API Key mock server is ready!"
break
fi
sleep 3
done
- name: Wait for web server to be ready
run: |
echo "Waiting for web server on port 3000..."
start_time=$(date +%s)
timeout=120
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Web server did not become ready in ${timeout}s."
exit 1
fi
if curl -sf "http://localhost:3000/api/health" > /dev/null 2>&1 || \
curl -sf "http://localhost:3000/" > /dev/null 2>&1; then
echo "Web server is ready!"
break
fi
echo "Web server not ready yet. Retrying in 3 seconds..."
sleep 3
done
- name: Run Playwright tests
working-directory: ./web
env:
@@ -382,7 +318,6 @@ jobs:
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests]
if: ${{ always() }}
steps:
@@ -408,7 +343,7 @@ jobs:
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
# uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
# with:
# fetch-depth: 0

View File

@@ -14,55 +14,19 @@ permissions:
contents: read
jobs:
validate-requirements:
runs-on: ubuntu-slim
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup uv
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
# TODO: Enable caching once there is a uv.lock file checked in.
# with:
# enable-cache: true
- name: Validate requirements lock files
run: ./backend/scripts/compile_requirements.py --check
mypy-check:
# See https://runs-on.com/runners/linux/
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Generate OpenAPI schema
shell: bash
working-directory: backend
env:
PYTHONPATH: "."
run: |
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -72,18 +36,11 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Generate OpenAPI Python client
shell: bash
run: |
docker run --rm \
-v "${{ github.workspace }}/backend/generated:/local" \
openapitools/openapi-generator-cli generate \
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
- name: Prepare build
uses: ./.github/actions/prepare-build
with:
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-password: ${{ secrets.DOCKER_TOKEN }}
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}

View File

@@ -126,7 +126,6 @@ jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-connectors-check", "extras=s3-cache"]
timeout-minutes: 45
env:
PYTHONPATH: ./backend
@@ -135,16 +134,12 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Setup Playwright
uses: ./.github/actions/setup-playwright

View File

@@ -32,14 +32,13 @@ jobs:
model-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
timeout-minutes: 45
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

View File

@@ -17,7 +17,6 @@ jobs:
backend-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-backend-check"]
timeout-minutes: 45
env:
@@ -31,18 +30,12 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -14,10 +14,9 @@ jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-quality-checks"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
- uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
@@ -27,13 +26,5 @@ jobs:
- name: Setup Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
env:
# uv-run is mypy's id and mypy is covered by the Python Checks which caches dependencies better.
SKIP: uv-run
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- name: Check Actions
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
with:
check_permissions: false
check_versions: false

View File

@@ -9,12 +9,11 @@ on:
jobs:
sync-foss:
runs-on: ubuntu-latest
timeout-minutes: 45
permissions:
contents: read
steps:
- name: Checkout main Onyx repo
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false

View File

@@ -11,14 +11,13 @@ permissions:
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
timeout-minutes: 45
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true

View File

@@ -12,12 +12,11 @@ jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
timeout-minutes: 45
permissions:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # ratchet:actions/checkout@v5.0.1
with:
persist-credentials: false

5
.gitignore vendored
View File

@@ -1,7 +1,6 @@
# editors
.vscode
.zed
.cursor
# macos
.DS_store
@@ -29,8 +28,6 @@ settings.json
# others
/deployment/data/nginx/app.conf
/deployment/data/nginx/mcp.conf.inc
/deployment/data/nginx/mcp_upstream.conf.inc
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
*.egg-info
@@ -49,7 +46,5 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl
node_modules
# MCP configs
.playwright-mcp

8
.mcp.json.template Normal file
View File

@@ -0,0 +1,8 @@
{
"mcpServers": {
"onyx-mcp": {
"type": "http",
"url": "http://localhost:8000/mcp"
}
}
}

View File

@@ -1,20 +1,4 @@
default_install_hook_types:
- pre-commit
- post-checkout
- post-merge
- post-rewrite
repos:
- repo: https://github.com/astral-sh/uv-pre-commit
# This revision is from https://github.com/astral-sh/uv-pre-commit/pull/53
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-sync
- id: uv-run
name: mypy
args: ["mypy"]
pass_filenames: true
files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
@@ -87,3 +71,31 @@ repos:
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/(?!\.venv/).*\.py$
# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type
# stubs. This does not seem worth it at the moment, so for now we will stick to
# having mypy run via Github Actions / manually by contributors
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.1.1
# hooks:
# - id: mypy
# exclude: ^tests/
# # below are needed for type stubs since pre-commit runs in it's own
# # isolated environment. Unfortunately, this needs to be kept in sync
# # with requirements/dev.txt + requirements/default.txt
# additional_dependencies: [
# alembic==1.10.4,
# types-beautifulsoup4==4.12.0.3,
# types-html5lib==1.1.11.13,
# types-oauthlib==3.2.0.9,
# types-psycopg2==2.9.21.10,
# types-python-dateutil==2.8.19.13,
# types-regex==2023.3.23.1,
# types-requests==2.28.11.17,
# types-retry==0.9.9.3,
# types-urllib3==1.26.25.11
# ]
# # TODO: add back once errors are addressed
# # args: [--strict]

View File

@@ -20,7 +20,6 @@
"Web Server",
"Model Server",
"API Server",
"MCP Server",
"Slack Bot",
"Celery primary",
"Celery light",
@@ -153,34 +152,6 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"MCP_SERVER_ENABLED": "true",
"MCP_SERVER_PORT": "8090",
"MCP_SERVER_CORS_ORIGINS": "http://localhost:*",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"onyx.mcp_server.api:mcp_app",
"--reload",
"--port",
"8090",
"--timeout-graceful-shutdown",
"0"
],
"presentation": {
"group": "2"
},
"consoleTitle": "MCP Server Console"
},
{
"name": "Celery primary",
"type": "debugpy",

View File

@@ -12,13 +12,6 @@ ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/
# Install system dependencies
@@ -58,7 +51,6 @@ RUN uv pip install --system --no-cache-dir --upgrade \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
chown -R onyx:onyx /app && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord && \
# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
@@ -102,6 +94,13 @@ tiktoken.get_encoding('cl100k_base')"
# Set up application files
WORKDIR /app
# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# Enterprise Version Files
COPY --chown=onyx:onyx ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf

View File

@@ -1,104 +0,0 @@
"""add_open_url_tool
Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None
OPEN_URL_TOOL = {
"name": "OpenURLTool",
"display_name": "Open URL",
"description": (
"The Open URL Action allows the agent to fetch and read contents of web pages."
),
"in_code_tool_id": "OpenURLTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
tool_id = existing[0]
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
OPEN_URL_TOOL,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
OPEN_URL_TOOL,
)
# Get the newly inserted tool's id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
tool_id = result[0] # type: ignore
# Associate the tool with all existing personas
# Get all persona IDs
persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()
for (persona_id,) in persona_ids:
# Check if association already exists
exists = conn.execute(
sa.text(
"""
SELECT 1 FROM persona__tool
WHERE persona_id = :persona_id AND tool_id = :tool_id
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
).fetchone()
if not exists:
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (:persona_id, :tool_id)
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
)
def downgrade() -> None:
# We don't remove the tool on downgrade since it's fine to have it around.
# If we upgrade again, it will be a no-op.
pass

View File

@@ -1,44 +0,0 @@
"""add_created_at_in_project_userfile
Revision ID: 6436661d5b65
Revises: c7e9f4a3b2d1
Create Date: 2025-11-24 11:50:24.536052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6436661d5b65"
down_revision = "c7e9f4a3b2d1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add created_at column to project__user_file table
op.add_column(
"project__user_file",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
)
# Add composite index on (project_id, created_at DESC)
op.create_index(
"ix_project__user_file_project_id_created_at",
"project__user_file",
["project_id", sa.text("created_at DESC")],
)
def downgrade() -> None:
# Remove composite index on (project_id, created_at)
op.drop_index(
"ix_project__user_file_project_id_created_at", table_name="project__user_file"
)
# Remove created_at column from project__user_file table
op.drop_column("project__user_file", "created_at")

View File

@@ -1,572 +0,0 @@
"""New Chat History
Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop research agent tables (if they exist)
op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
# Drop agent sub query and sub question tables (if they exist)
op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")
# Update ChatMessage table
# Rename parent_message to parent_message_id and make it a foreign key (if not already done)
conn = op.get_bind()
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'parent_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message", "parent_message", new_column_name="parent_message_id"
)
op.create_foreign_key(
"fk_chat_message_parent_message_id",
"chat_message",
"chat_message",
["parent_message_id"],
["id"],
)
# Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message",
"latest_child_message",
new_column_name="latest_child_message_id",
)
op.create_foreign_key(
"fk_chat_message_latest_child_message_id",
"chat_message",
"chat_message",
["latest_child_message_id"],
["id"],
)
# Add reasoning_tokens column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Drop columns no longer needed (if they exist)
for col in [
"rephrased_query",
"alternate_assistant_id",
"overridden_model",
"is_agentic",
"refined_answer_improvement",
"research_type",
"research_plan",
"research_answer_purpose",
]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = '{col}'
"""
)
)
if result.fetchone():
op.drop_column("chat_message", col)
# Update ToolCall table
# Add chat_session_id column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
)
op.create_foreign_key(
"fk_tool_call_chat_session_id",
"tool_call",
"chat_session",
["chat_session_id"],
["id"],
)
# Rename message_id to parent_chat_message_id and make nullable (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'message_id'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call",
"message_id",
new_column_name="parent_chat_message_id",
nullable=True,
)
# Add parent_tool_call_id (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_tool_call_parent_tool_call_id",
"tool_call",
"tool_call",
["parent_tool_call_id"],
["id"],
)
op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
# Add turn_number, tool_id (if not exists)
for col_name in ["turn_number", "tool_id"]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = '{col_name}'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
)
# Add tool_call_id as String (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
)
# Add reasoning_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Rename tool_arguments to tool_call_arguments (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call", "tool_arguments", new_column_name="tool_call_arguments"
)
# Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name, data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_result'
"""
)
)
tool_result_row = result.fetchone()
if tool_result_row:
op.alter_column(
"tool_call", "tool_result", new_column_name="tool_call_response"
)
# Change type from JSONB to Text
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
else:
# Check if tool_call_response already exists and is JSONB, then convert to Text
result = conn.execute(
sa.text(
"""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
"""
)
)
tool_call_response_row = result.fetchone()
if tool_call_response_row and tool_call_response_row[0] == "jsonb":
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
# Add tool_call_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(
"tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
),
)
# Add generated_images column for image generation tool replay (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'generated_images'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
)
# Drop tool_name column (if exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_name'
"""
)
)
if result.fetchone():
op.drop_column("tool_call", "tool_name")
# Create tool_call__search_doc association table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT table_name FROM information_schema.tables
WHERE table_name = 'tool_call__search_doc'
"""
)
)
if not result.fetchone():
op.create_table(
"tool_call__search_doc",
sa.Column("tool_call_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
)
# Add replace_base_system_prompt to persona table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
"""
)
)
if not result.fetchone():
op.add_column(
"persona",
sa.Column(
"replace_base_system_prompt",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
# Reverse persona changes
op.drop_column("persona", "replace_base_system_prompt")
# Drop tool_call__search_doc association table
op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")
# Reverse ToolCall changes
op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
op.drop_column("tool_call", "tool_id")
op.drop_column("tool_call", "tool_call_tokens")
op.drop_column("tool_call", "generated_images")
# Change tool_call_response back to JSONB before renaming
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE JSONB
USING tool_call_response::jsonb
"""
)
)
op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
op.alter_column(
"tool_call", "tool_call_arguments", new_column_name="tool_arguments"
)
op.drop_column("tool_call", "reasoning_tokens")
op.drop_column("tool_call", "tool_call_id")
op.drop_column("tool_call", "turn_number")
op.drop_constraint(
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
)
op.drop_column("tool_call", "parent_tool_call_id")
op.alter_column(
"tool_call",
"parent_chat_message_id",
new_column_name="message_id",
nullable=False,
)
op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "chat_session_id")
op.add_column(
"chat_message",
sa.Column(
"research_answer_purpose",
sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
nullable=True,
),
)
op.add_column(
"chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
)
op.add_column(
"chat_message",
sa.Column(
"research_type",
sa.Enum("SIMPLE", "DEEP", name="researchtype"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)
op.drop_column("chat_message", "reasoning_tokens")
op.drop_constraint(
"fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message",
"latest_child_message_id",
new_column_name="latest_child_message",
)
op.drop_constraint(
"fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message", "parent_message_id", new_column_name="parent_message"
)
# Recreate agent sub question and sub query tables
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_question", sa.Text(), nullable=False),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_answer", sa.Text(), nullable=False),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("parent_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_query", sa.Text(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query__search_doc",
sa.Column("sub_query_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
)
# Recreate research agent tables
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
sa.Column("queries", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -1,73 +0,0 @@
"""add_python_tool
Revision ID: c7e9f4a3b2d1
Revises: 3c9a65f1207f
Create Date: 2025-11-08 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c7e9f4a3b2d1"
down_revision = "3c9a65f1207f"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add PythonTool to built-in tools"""
conn = op.get_bind()
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
{
"name": "PythonTool",
# in the UI, call it `Code Interpreter` since this is a well known term for this tool
"display_name": "Code Interpreter",
"description": (
"The Code Interpreter Action allows the assistant to execute "
"Python code in a secure, isolated environment for data analysis, "
"computation, visualization, and file processing."
),
"in_code_tool_id": "PythonTool",
"enabled": True,
},
)
# needed to store files generated by the python tool
op.add_column(
"research_agent_iteration_sub_step",
sa.Column(
"file_ids",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
"""Remove PythonTool from built-in tools"""
conn = op.get_bind()
conn.execute(
sa.text(
"""
DELETE FROM tool
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"in_code_tool_id": "PythonTool",
},
)
op.drop_column("research_agent_iteration_sub_step", "file_ids")

View File

@@ -199,7 +199,10 @@ def fetch_persona_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == persona_id,
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -228,7 +231,10 @@ def fetch_persona_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == persona_id,
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -259,7 +265,10 @@ def fetch_assistant_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -290,7 +299,10 @@ def fetch_assistant_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -320,7 +332,10 @@ def fetch_assistant_unique_users_total(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
ChatSession.persona_id == assistant_id,
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,

View File

@@ -55,7 +55,18 @@ def get_empty_chat_messages_entries__paginated(
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if chat_session.persona:
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(

View File

@@ -581,48 +581,6 @@ def update_user_curator_relationship(
db_session.commit()
def add_users_to_user_group(
db_session: Session,
user: User | None,
user_group_id: int,
user_ids: list[UUID],
) -> UserGroup:
db_user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
missing_users = [
user_id for user_id in user_ids if fetch_user_by_id(db_session, user_id) is None
]
if missing_users:
raise ValueError(
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
)
_check_user_group_is_modifiable(db_user_group)
current_user_ids = [user.id for user in db_user_group.users]
current_user_ids_set = set(current_user_ids)
new_user_ids = [
user_id for user_id in user_ids if user_id not in current_user_ids_set
]
if not new_user_ids:
return db_user_group
user_group_update = UserGroupUpdate(
user_ids=current_user_ids + new_user_ids,
cc_pair_ids=[cc_pair.id for cc_pair in db_user_group.cc_pairs],
)
return update_user_group(
db_session=db_session,
user=user,
user_group_id=user_group_id,
user_group_update=user_group_update,
)
def update_user_group(
db_session: Session,
user: User | None,
@@ -645,17 +603,6 @@ def update_user_group(
added_user_ids = list(updated_user_ids - current_user_ids)
removed_user_ids = list(current_user_ids - updated_user_ids)
if added_user_ids:
missing_users = [
user_id
for user_id in added_user_ids
if fetch_user_by_id(db_session, user_id) is None
]
if missing_users:
raise ValueError(
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
)
# LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES
# ACCESS TO DIFFERENT PERMISSIONS
# if (removed_user_ids or added_user_ids) and (

View File

@@ -23,7 +23,7 @@ from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
from ee.onyx.server.query_and_chat.query_backend import (
basic_router as ee_query_router,
basic_router as query_router,
)
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
@@ -48,9 +48,6 @@ from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.main import lifespan as lifespan_base
from onyx.main import use_route_function_names_as_operation_ids
from onyx.server.query_and_chat.query_backend import (
basic_router as query_router,
)
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -122,7 +119,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_history_router)
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, ee_query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_history_chain(
parent_message, _ = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:

View File

@@ -6,14 +6,18 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
@@ -25,12 +29,14 @@ class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
class DocumentSearchRequest(ChunkContext):
message: str
search_type: SearchType
retrieval_options: RetrievalDetails
recency_bias_multiplier: float = 1.0
evaluation_type: LLMEvaluationType
# None to use system defaults for reranking
rerank_settings: RerankingDetails | None = None
class BasicCreateChatMessageRequest(ChunkContext):
@@ -90,17 +96,17 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(BaseModel):
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(BaseModel):
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(BaseModel):
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@@ -146,3 +152,45 @@ class AgentSubQuery(BaseModel):
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits
# Easier APIs to work with for developers
persona_override_config: PersonaOverrideConfig | None = None
persona_id: int | None = None
messages: list[ThreadMessage]
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:
raise ValueError("Exactly one of persona_config or persona_id must be set")
elif self.persona_override_config is not None and (self.persona_id is not None):
raise ValueError(
"If persona_override_config is set, persona_id cannot be set"
)
return self
class OneShotQAResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
rephrase: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
error_msg: str | None = None
chat_message_id: int | None = None

View File

@@ -1,23 +1,316 @@
import json
from collections.abc import Generator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
oneoff_standard_answers,
)
from ee.onyx.server.query_and_chat.models import DocumentSearchRequest
from ee.onyx.server.query_and_chat.models import OneShotQARequest
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from ee.onyx.server.query_and_chat.models import StandardAnswerRequest
from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
logger = setup_logger()
logger = setup_logger()
basic_router = APIRouter(prefix="/query")
class DocumentSearchPagination(BaseModel):
offset: int
limit: int
returned_count: int
has_more: bool
next_offset: int | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[SavedSearchDocWithContent]
llm_indices: list[int]
pagination: DocumentSearchPagination
def _normalize_pagination(limit: int | None, offset: int | None) -> tuple[int, int]:
if limit is None:
resolved_limit = NUM_RETURNED_HITS
else:
resolved_limit = limit
if resolved_limit <= 0:
raise HTTPException(
status_code=400, detail="retrieval_options.limit must be positive"
)
if offset is None:
resolved_offset = 0
else:
resolved_offset = offset
if resolved_offset < 0:
raise HTTPException(
status_code=400, detail="retrieval_options.offset cannot be negative"
)
return resolved_limit, resolved_offset
@basic_router.post("/document-search")
def handle_search_request(
search_request: DocumentSearchRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DocumentSearchResponse:
"""Simple search endpoint, does not create a new message or records in the DB"""
query = search_request.message
logger.notice(f"Received document search query: {query}")
llm, fast_llm = get_default_llms()
pagination_limit, pagination_offset = _normalize_pagination(
limit=search_request.retrieval_options.limit,
offset=search_request.retrieval_options.offset,
)
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
search_type=search_request.search_type,
human_selected_filters=search_request.retrieval_options.filters,
enable_auto_detect_filters=search_request.retrieval_options.enable_auto_detect_filters,
persona=None, # For simplicity, default settings should be good for this search
offset=pagination_offset,
limit=pagination_limit + 1,
rerank_settings=search_request.rerank_settings,
evaluation_type=search_request.evaluation_type,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,
full_doc=search_request.full_doc,
),
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)
top_sections = search_pipeline.reranked_sections
relevance_sections = search_pipeline.section_relevance
top_docs = [
SavedSearchDocWithContent(
document_id=section.center_chunk.document_id,
chunk_ind=section.center_chunk.chunk_id,
content=section.center_chunk.content,
semantic_identifier=section.center_chunk.semantic_identifier or "Unknown",
link=(
section.center_chunk.source_links.get(0)
if section.center_chunk.source_links
else None
),
blurb=section.center_chunk.blurb,
source_type=section.center_chunk.source_type,
boost=section.center_chunk.boost,
hidden=section.center_chunk.hidden,
metadata=section.center_chunk.metadata,
score=section.center_chunk.score or 0.0,
match_highlights=section.center_chunk.match_highlights,
updated_at=section.center_chunk.updated_at,
primary_owners=section.center_chunk.primary_owners,
secondary_owners=section.center_chunk.secondary_owners,
is_internet=False,
db_doc_id=0,
)
for section in top_sections
]
# Track whether the underlying retrieval produced more items than requested
has_more_results = len(top_docs) > pagination_limit
# Deduping happens at the last step to avoid harming quality by dropping content early on
deduped_docs = top_docs
dropped_inds = None
if search_request.retrieval_options.dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections, items=deduped_docs
)
if dropped_inds:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=deduped_docs,
dropped_indices=dropped_inds,
)
paginated_docs = deduped_docs[:pagination_limit]
llm_indices = [index for index in llm_indices if index < len(paginated_docs)]
has_more = has_more_results
pagination = DocumentSearchPagination(
offset=pagination_offset,
limit=pagination_limit,
returned_count=len(paginated_docs),
has_more=has_more,
next_offset=(pagination_offset + pagination_limit) if has_more else None,
)
return DocumentSearchResponse(
top_documents=paginated_docs,
llm_indices=llm_indices,
pagination=pagination,
)
def get_answer_stream(
query_request: OneShotQARequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AnswerStream:
query = query_request.messages[0].message
logger.notice(f"Received query for Answer API: {query}")
if (
query_request.persona_override_config is None
and query_request.persona_id is None
):
raise KeyError("Must provide persona ID or Persona Config")
persona_info: Persona | PersonaOverrideConfig | None = None
if query_request.persona_override_config is not None:
persona_info = query_request.persona_override_config
elif query_request.persona_id is not None:
persona_info = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
max_history_tokens = int(
llm.config.max_input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE
)
combined_message = combine_message_thread(
messages=query_request.messages,
max_tokens=max_history_tokens,
llm_tokenizer=llm_tokenizer,
)
# Also creates a new chat session
request = prepare_chat_message_request(
message_text=combined_message,
user=user,
persona_id=query_request.persona_id,
persona_override_config=query_request.persona_override_config,
message_ts_to_respond_to=None,
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(
new_msg_req=request,
user=user,
db_session=db_session,
)
return packets
@basic_router.post("/answer-with-citation")
def get_answer_with_citation(
request: OneShotQARequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> OneShotQAResponse:
try:
packets = get_answer_stream(request, user, db_session)
answer = gather_stream(packets)
if answer.error_msg:
raise RuntimeError(answer.error_msg)
return OneShotQAResponse(
answer=answer.answer,
chat_message_id=answer.message_id,
error_msg=answer.error_msg,
citations=[
CitationInfo(citation_num=i, document_id=doc_id)
for i, doc_id in answer.cited_documents.items()
],
docs=QADocsResponse(
top_documents=answer.top_documents,
predicted_flow=None,
predicted_search=None,
applied_source_filters=None,
applied_time_cutoff=None,
recency_bias_multiplier=0.0,
),
)
except Exception as e:
logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="An internal server error occurred")
@basic_router.post("/stream-answer-with-citation")
def stream_answer_with_citation(
request: OneShotQARequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> StreamingResponse:
def stream_generator() -> Generator[str, None, None]:
try:
for packet in get_answer_stream(request, user, db_session):
serialized = get_json_line(packet.model_dump())
yield serialized
except Exception as e:
logger.exception("Error in answer streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")
@basic_router.get("/standard-answer")
def get_standard_answer(
request: StandardAnswerRequest,

View File

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,9 +123,10 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
messages = create_chat_history_chain(
last_message, messages = create_chat_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None

View File

@@ -4,14 +4,12 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
@@ -81,26 +79,6 @@ def patch_user_group(
raise HTTPException(status_code=404, detail=str(e))
@router.post("/admin/user-group/{user_group_id}/add-users")
def add_users(
user_group_id: int,
add_users_request: AddUsersToUserGroupRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
add_users_to_user_group(
db_session=db_session,
user=user,
user_group_id=user_group_id,
user_ids=add_users_request.user_ids,
)
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/admin/user-group/{user_group_id}/set-curator")
def set_user_curator(
user_group_id: int,

View File

@@ -87,10 +87,6 @@ class UserGroupUpdate(BaseModel):
cc_pair_ids: list[int]
class AddUsersToUserGroupRequest(BaseModel):
user_ids: list[UUID]
class SetCuratorRequest(BaseModel):
user_id: UUID
is_curator: bool

View File

@@ -1,309 +1,365 @@
# import json
# from collections.abc import Callable
# from collections.abc import Iterator
# from collections.abc import Sequence
# from dataclasses import dataclass
# from typing import Any
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
# from onyx.agents.agent_framework.models import RunItemStreamEvent
# from onyx.agents.agent_framework.models import StreamEvent
# from onyx.agents.agent_framework.models import ToolCallStreamItem
# from onyx.llm.interfaces import LanguageModelInput
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.llm.message_types import ChatCompletionMessage
# from onyx.llm.message_types import ToolCall
# from onyx.llm.model_response import ModelResponseStream
# from onyx.tools.tool import Tool
# from onyx.tracing.framework.create import agent_span
# from onyx.tracing.framework.create import generation_span
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
# @dataclass
# class QueryResult:
# stream: Iterator[StreamEvent]
# new_messages_stateful: list[ChatCompletionMessage]
@dataclass
class QueryResult:
stream: Iterator[StreamEvent]
new_messages_stateful: list[ChatCompletionMessage]
# def _serialize_tool_output(output: Any) -> str:
# if isinstance(output, str):
# return output
# try:
# return json.dumps(output)
# except TypeError:
# return str(output)
def _serialize_tool_output(output: Any) -> str:
if isinstance(output, str):
return output
try:
return json.dumps(output)
except TypeError:
return str(output)
# def _parse_tool_calls_from_message_content(
# content: str,
# ) -> list[dict[str, Any]]:
# """Parse JSON content that represents tool call instructions."""
# try:
# parsed_content = json.loads(content)
# except json.JSONDecodeError:
# return []
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
# if isinstance(parsed_content, dict):
# candidates = [parsed_content]
# elif isinstance(parsed_content, list):
# candidates = [item for item in parsed_content if isinstance(item, dict)]
# else:
# return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
# tool_calls: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
# for candidate in candidates:
# name = candidate.get("name")
# arguments = candidate.get("arguments")
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
# if not isinstance(name, str) or arguments is None:
# continue
if not isinstance(name, str) or arguments is None:
continue
# if not isinstance(arguments, dict):
# continue
if not isinstance(arguments, dict):
continue
# call_id = candidate.get("id")
# arguments_str = json.dumps(arguments)
# tool_calls.append(
# {
# "id": call_id,
# "name": name,
# "arguments": arguments_str,
# }
# )
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
# return tool_calls
return tool_calls
# def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# content_parts: list[str],
# structured_response_format: dict | None,
# next_synthetic_tool_call_id: Callable[[], str],
# ) -> None:
# """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
# if tool_calls_in_progress or not content_parts or structured_response_format:
# return
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
# tool_calls_from_content = _parse_tool_calls_from_message_content(
# "".join(content_parts)
# )
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
# if not tool_calls_from_content:
# return
if not tool_calls_from_content:
return
# content_parts.clear()
content_parts.clear()
# for index, tool_call_data in enumerate(tool_calls_from_content):
# call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
# tool_calls_in_progress[index] = {
# "id": call_id,
# "name": tool_call_data["name"],
# "arguments": tool_call_data["arguments"],
# }
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
# def _update_tool_call_with_delta(
# tool_calls_in_progress: dict[int, dict[str, Any]],
# tool_call_delta: Any,
# ) -> None:
# index = tool_call_delta.index
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
) -> None:
index = tool_call_delta.index
# if index not in tool_calls_in_progress:
# tool_calls_in_progress[index] = {
# "id": None,
# "name": None,
# "arguments": "",
# }
if index not in tool_calls_in_progress:
tool_calls_in_progress[index] = {
"id": None,
"name": None,
"arguments": "",
}
# if tool_call_delta.id:
# tool_calls_in_progress[index]["id"] = tool_call_delta.id
if tool_call_delta.id:
tool_calls_in_progress[index]["id"] = tool_call_delta.id
# if tool_call_delta.function:
# if tool_call_delta.function.name:
# tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
if tool_call_delta.function:
if tool_call_delta.function.name:
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
# if tool_call_delta.function.arguments:
# tool_calls_in_progress[index][
# "arguments"
# ] += tool_call_delta.function.arguments
if tool_call_delta.function.arguments:
tool_calls_in_progress[index][
"arguments"
] += tool_call_delta.function.arguments
# def query(
# llm_with_default_settings: LLM,
# messages: LanguageModelInput,
# tools: Sequence[Tool],
# context: Any,
# tool_choice: ToolChoiceOptions | None = None,
# structured_response_format: dict | None = None,
# ) -> QueryResult:
# tool_definitions = [tool.tool_definition() for tool in tools]
# tools_by_name = {tool.name: tool for tool in tools}
def query(
llm_with_default_settings: LLM,
messages: LanguageModelInput,
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
# new_messages_stateful: list[ChatCompletionMessage] = []
new_messages_stateful: list[ChatCompletionMessage] = []
# current_span = agent_span(
# name="agent_framework_query",
# output_type="dict" if structured_response_format else "str",
# )
# current_span.start(mark_as_current=True)
# current_span.span_data.tools = [t.name for t in tools]
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
# def stream_generator() -> Iterator[StreamEvent]:
# message_started = False
# reasoning_started = False
def stream_generator() -> Iterator[StreamEvent]:
message_started = False
reasoning_started = False
# tool_calls_in_progress: dict[int, dict[str, Any]] = {}
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
# content_parts: list[str] = []
content_parts: list[str] = []
# synthetic_tool_call_counter = 0
synthetic_tool_call_counter = 0
# def _next_synthetic_tool_call_id() -> str:
# nonlocal synthetic_tool_call_counter
# call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
# synthetic_tool_call_counter += 1
# return call_id
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
# with generation_span( # type: ignore[misc]
# model=llm_with_default_settings.config.model_name,
# model_config={
# "base_url": str(llm_with_default_settings.config.api_base or ""),
# "model_impl": "litellm",
# },
# ) as span_generation:
# # Only set input if messages is a sequence (not a string)
# # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
# if isinstance(messages, Sequence) and not isinstance(messages, str):
# # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
# span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
# for chunk in llm_with_default_settings.stream(
# prompt=messages,
# tools=tool_definitions,
# tool_choice=tool_choice,
# structured_response_format=structured_response_format,
# ):
# assert isinstance(chunk, ModelResponseStream)
# usage = getattr(chunk, "usage", None)
# if usage:
# span_generation.span_data.usage = {
# "input_tokens": usage.prompt_tokens,
# "output_tokens": usage.completion_tokens,
# "cache_read_input_tokens": usage.cache_read_input_tokens,
# "cache_creation_input_tokens": usage.cache_creation_input_tokens,
# }
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
# delta = chunk.choice.delta
# finish_reason = chunk.choice.finish_reason
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
# if delta.reasoning_content:
# if not reasoning_started:
# yield RunItemStreamEvent(type="reasoning_start")
# reasoning_started = True
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
# if delta.content:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# content_parts.append(delta.content)
# if not message_started:
# yield RunItemStreamEvent(type="message_start")
# message_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
# if delta.tool_calls:
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# for tool_call_delta in delta.tool_calls:
# _update_tool_call_with_delta(
# tool_calls_in_progress, tool_call_delta
# )
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
# yield chunk
yield chunk
# if not finish_reason:
# continue
if not finish_reason:
continue
# if reasoning_started:
# yield RunItemStreamEvent(type="reasoning_done")
# reasoning_started = False
# if message_started:
# yield RunItemStreamEvent(type="message_done")
# message_started = False
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
# if tool_choice != "none":
# _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
# tool_calls_in_progress,
# content_parts,
# structured_response_format,
# _next_synthetic_tool_call_id,
# )
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
)
# if content_parts:
# new_messages_stateful.append(
# {
# "role": "assistant",
# "content": "".join(content_parts),
# }
# )
# span_generation.span_data.output = new_messages_stateful
if content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
span_generation.span_data.output = new_messages_stateful
# # Execute tool calls outside of the stream loop and generation_span
# if tool_calls_in_progress:
# sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# # Build tool calls for the message and execute tools
# assistant_tool_calls: list[ToolCall] = []
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
# for _, tool_call_data in sorted_tool_calls:
# call_id = tool_call_data["id"]
# name = tool_call_data["name"]
# arguments_str = tool_call_data["arguments"]
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
# if call_id is None or name is None:
# continue
if call_id is None or name is None:
continue
# assistant_tool_calls.append(
# {
# "id": call_id,
# "type": "function",
# "function": {
# "name": name,
# "arguments": arguments_str,
# },
# }
# )
assistant_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
# yield RunItemStreamEvent(
# type="tool_call",
# details=ToolCallStreamItem(
# call_id=call_id,
# name=name,
# arguments=arguments_str,
# ),
# )
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
# if name in tools_by_name:
# tools_by_name[name]
# json.loads(arguments_str)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
# run_context = RunContextWrapper(context=context)
run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
# TODO broken for now, no need for a run_v2
# output = tool.run_v2(run_context, **arguments)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
# yield RunItemStreamEvent(
# type="tool_call_output",
# details=ToolCallOutputStreamItem(
# call_id=call_id,
# output=output,
# ),
# )
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),
new_messages_stateful=new_messages_stateful,
)

View File

@@ -26,9 +26,9 @@ def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> Non
# Without this patch, the library uses special formatting that breaks our custom tools
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
if tool_choice == "web_search":
return {"type": "function", "name": "web_search"}
return "web_search"
if tool_choice == "image_generation":
return {"type": "function", "name": "image_generation"}
return "image_generation"
return orig_func(cls, tool_choice)
OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]

View File

@@ -1,21 +1,21 @@
# from operator import add
# from typing import Annotated
from operator import add
from typing import Annotated
# from pydantic import BaseModel
from pydantic import BaseModel
# class CoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# log_messages: Annotated[list[str], add] = []
# current_step_nr: int = 1
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
# class SubgraphCoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# log_messages: Annotated[list[str], add] = []
log_messages: Annotated[list[str], add] = []

View File

@@ -1,62 +1,62 @@
# from collections.abc import Hashable
# from typing import cast
from collections.abc import Hashable
from typing import cast
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
# def parallel_object_source_research_edge(
# state: SearchSourcesObjectsUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
# search_objects = state.analysis_objects
# search_sources = state.analysis_sources
search_objects = state.analysis_objects
search_sources = state.analysis_sources
# object_source_combinations = [
# (object, source) for object in search_objects for source in search_sources
# ]
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
# return [
# Send(
# "research_object_source",
# ObjectSourceInput(
# object_source_combination=object_source_combination,
# log_messages=[],
# ),
# )
# for object_source_combination in object_source_combinations
# ]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
# def parallel_object_research_consolidation_edge(
# state: ObjectResearchInformationUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
# cast(GraphConfig, config["metadata"]["config"])
# object_research_information_results = state.object_research_information_results
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
# return [
# Send(
# "consolidate_object_research",
# ObjectInformationInput(
# object_information=object_information,
# log_messages=[],
# ),
# )
# for object_information in object_research_information_results
# ]
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]

View File

@@ -1,103 +1,103 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
# search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
# research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
# structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
# consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
# consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# test_mode = False
test_mode = False
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# ### Add nodes ###
### Add nodes ###
# graph.add_node(
# "search_objects",
# search_objects,
# )
graph.add_node(
"search_objects",
search_objects,
)
# graph.add_node(
# "structure_research_by_source",
# structure_research_by_object,
# )
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
# graph.add_node(
# "research_object_source",
# research_object_source,
# )
graph.add_node(
"research_object_source",
research_object_source,
)
# graph.add_node(
# "consolidate_object_research",
# consolidate_object_research,
# )
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
# graph.add_node(
# "consolidate_research",
# consolidate_research,
# )
graph.add_node(
"consolidate_research",
consolidate_research,
)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="search_objects")
graph.add_edge(start_key=START, end_key="search_objects")
# graph.add_conditional_edges(
# source="search_objects",
# path=parallel_object_source_research_edge,
# path_map=["research_object_source"],
# )
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
# graph.add_edge(
# start_key="research_object_source",
# end_key="structure_research_by_source",
# )
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
# graph.add_conditional_edges(
# source="structure_research_by_source",
# path=parallel_object_research_consolidation_edge,
# path_map=["consolidate_object_research"],
# )
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
# graph.add_edge(
# start_key="consolidate_object_research",
# end_key="consolidate_research",
# )
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
# graph.add_edge(
# start_key="consolidate_research",
# end_key=END,
# )
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
# return graph
return graph

View File

@@ -1,146 +1,146 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def search_objects(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = graph_config.inputs.prompt_builder.raw_user_query
# search_tool = graph_config.tooling.search_tool
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
search_tool = graph_config.tooling.search_tool
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# agent_1_instructions = extract_section(
# instructions, "Agent Step 1:", "Agent Step 2:"
# )
# if agent_1_instructions is None:
# raise ValueError("Agent 1 instructions not found")
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
# agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_1_task = extract_section(
# agent_1_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_1_task is None:
# raise ValueError("Agent 1 task not found")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
# agent_1_independent_sources_str = extract_section(
# agent_1_instructions, "Independent Research Sources:", "Output Objective:"
# )
# if agent_1_independent_sources_str is None:
# raise ValueError("Agent 1 Independent Research Sources not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
# document_sources = strings_to_document_sources(
# [
# x.strip().lower()
# for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
# ]
# )
document_sources = strings_to_document_sources(
[
x.strip().lower()
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
)
# agent_1_output_objective = extract_section(
# agent_1_instructions, "Output Objective:"
# )
# if agent_1_output_objective is None:
# raise ValueError("Agent 1 output objective not found")
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
# except Exception as e:
# raise ValueError(
# f"Agent 1 instructions not found or not formatted correctly: {e}"
# )
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# # Extract objects
# Extract objects
# if agent_1_base_data is None:
# # Retrieve chunks for objects
if agent_1_base_data is None:
# Retrieve chunks for objects
# retrieved_docs = research(question, search_tool)[:10]
retrieved_docs = research(question, search_tool)[:10]
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts = "\n\n".join(document_texts_list)
document_texts = "\n\n".join(document_texts_list)
# dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# document_text=document_texts,
# objects_of_interest=agent_1_output_objective,
# )
# else:
# dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# base_data=agent_1_base_data,
# objects_of_interest=agent_1_output_objective,
# )
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_extraction_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# cleaned_response = cleaned_response.split("OBJECTS:")[1]
# object_list = [x.strip() for x in cleaned_response.split(";")]
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
# except Exception as e:
# raise ValueError(f"Error in search_objects: {e}")
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
# return SearchSourcesObjectsUpdate(
# analysis_objects=object_list,
# analysis_sources=document_sources,
# log_messages=["Agent 1 Task done"],
# )
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)

View File

@@ -1,180 +1,180 @@
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def research_object_source(
# state: ObjectSourceInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# datetime.now()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
# object, document_source = state.object_source_combination
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
object, document_source = state.object_source_combination
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# agent_2_instructions = extract_section(
# instructions, "Agent Step 2:", "Agent Step 3:"
# )
# if agent_2_instructions is None:
# raise ValueError("Agent 2 instructions not found")
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
# agent_2_task = extract_section(
# agent_2_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_2_task is None:
# raise ValueError("Agent 2 task not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
# agent_2_time_cutoff = extract_section(
# agent_2_instructions, "Time Cutoff:", "Research Topics:"
# )
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
# agent_2_research_topics = extract_section(
# agent_2_instructions, "Research Topics:", "Output Objective"
# )
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
# agent_2_output_objective = extract_section(
# agent_2_instructions, "Output Objective:"
# )
# if agent_2_output_objective is None:
# raise ValueError("Agent 2 output objective not found")
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
# except Exception:
# raise ValueError(
# "Agent 1 instructions not found or not formatted correctly: {e}"
# )
except Exception:
raise ValueError(
"Agent 1 instructions not found or not formatted correctly: {e}"
)
# # Populate prompt
# Populate prompt
# # Retrieve chunks for objects
# Retrieve chunks for objects
# if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
# if agent_2_time_cutoff.strip().endswith("d"):
# try:
# days = int(agent_2_time_cutoff.strip()[:-1])
# agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
# days=days
# )
# except ValueError:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# agent_2_source_start_time = None
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
# document_sources = [document_source] if document_source else None
document_sources = [document_source] if document_source else None
# if len(question.strip()) > 0:
# research_area = f"{question} for {object}"
# elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
# research_area = f"{agent_2_research_topics} for {object}"
# else:
# research_area = object
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
# retrieved_docs = research(
# question=research_area,
# search_tool=search_tool,
# document_sources=document_sources,
# time_cutoff=agent_2_source_start_time,
# )
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# # Generate document text
# Generate document text
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts = "\n\n".join(document_texts_list)
document_texts = "\n\n".join(document_texts_list)
# # Built prompt
# Built prompt
# today = datetime.now().strftime("%A, %Y-%m-%d")
today = datetime.now().strftime("%A, %Y-%m-%d")
# dc_object_source_research_prompt = (
# DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
# today=today,
# question=question,
# task=agent_2_task,
# document_text=document_texts,
# format=agent_2_output_objective,
# )
# .replace("---object---", object)
# .replace("---source---", document_source.value)
# )
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# # Run LLM
# Run LLM
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
# object_research_results = {
# "object": object,
# "source": document_source.value,
# "research_result": cleaned_response,
# }
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
# except Exception as e:
# raise ValueError(f"Error in research_object_source: {e}")
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
# return ObjectSourceResearchUpdate(
# object_source_research_results=[object_research_results],
# log_messages=["Agent Step 2 done for one object"],
# )
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)

View File

@@ -1,48 +1,48 @@
# from collections import defaultdict
# from typing import Dict
# from typing import List
from collections import defaultdict
from typing import Dict
from typing import List
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def structure_research_by_object(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
# object_source_research_results = state.object_source_research_results
object_source_research_results = state.object_source_research_results
# object_research_information_results: List[Dict[str, str]] = []
# object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
# for object_source_research in object_source_research_results:
# object = object_source_research["object"]
# source = object_source_research["source"]
# research_result = object_source_research["research_result"]
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
# object_research_information_results_list[object].append(
# f"Source: {source}\n{research_result}"
# )
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
# for object, information in object_research_information_results_list.items():
# object_research_information_results.append(
# {"object": object, "information": "\n".join(information)}
# )
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
# logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
# return ObjectResearchInformationUpdate(
# object_research_information_results=object_research_information_results,
# log_messages=["A3 - Object Research Information structured"],
# )
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)

View File

@@ -1,103 +1,103 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def consolidate_object_research(
# state: ObjectInformationInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# instructions = graph_config.inputs.persona.system_prompt or ""
instructions = graph_config.inputs.persona.system_prompt or ""
# agent_4_instructions = extract_section(
# instructions, "Agent Step 4:", "Agent Step 5:"
# )
# if agent_4_instructions is None:
# raise ValueError("Agent 4 instructions not found")
# agent_4_output_objective = extract_section(
# agent_4_instructions, "Output Objective:"
# )
# if agent_4_output_objective is None:
# raise ValueError("Agent 4 output objective not found")
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
# object_information = state.object_information
object_information = state.object_information
# object = object_information["object"]
# information = object_information["information"]
object = object_information["object"]
information = object_information["information"]
# # Create a prompt for the object consolidation
# Create a prompt for the object consolidation
# dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
# question=question,
# object=object,
# information=information,
# format=agent_4_output_objective,
# )
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# # Run LLM
# Run LLM
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_consolidation_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# consolidated_information = cleaned_response.split("INFORMATION:")[1]
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
# except Exception as e:
# raise ValueError(f"Error in consolidate_object_research: {e}")
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
# object_research_results = {
# "object": object,
# "research_result": consolidated_information,
# }
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
# logger.debug(
# "DivCon Step A4 - Object Research Consolidation - completed for an object"
# )
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
# return ObjectResearchUpdate(
# object_research_results=[object_research_results],
# log_messages=["Agent Source Consilidation done"],
# )
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,127 +1,127 @@
# from typing import cast
from typing import cast
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def consolidate_research(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
# graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
search_tool = graph_config.tooling.search_tool
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# # Populate prompt
# instructions = graph_config.inputs.persona.system_prompt or ""
# Populate prompt
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# agent_5_instructions = extract_section(
# instructions, "Agent Step 5:", "Agent End"
# )
# if agent_5_instructions is None:
# raise ValueError("Agent 5 instructions not found")
# agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_5_task = extract_section(
# agent_5_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_5_task is None:
# raise ValueError("Agent 5 task not found")
# agent_5_output_objective = extract_section(
# agent_5_instructions, "Output Objective:"
# )
# if agent_5_output_objective is None:
# raise ValueError("Agent 5 output objective not found")
# except ValueError as e:
# raise ValueError(
# f"Instructions for Agent Step 5 were not properly formatted: {e}"
# )
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
# research_result_list = []
research_result_list = []
# if agent_5_task.strip() == "*concatenate*":
# object_research_results = state.object_research_results
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
# for object_research_result in object_research_results:
# object = object_research_result["object"]
# research_result = object_research_result["research_result"]
# research_result_list.append(f"Object: {object}\n\n{research_result}")
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
# research_results = "\n\n".join(research_result_list)
research_results = "\n\n".join(research_result_list)
# else:
# raise NotImplementedError("Only '*concatenate*' is currently supported")
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# # Create a prompt for the object consolidation
# Create a prompt for the object consolidation
# if agent_5_base_data is None:
# dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
# text=research_results,
# format=agent_5_output_objective,
# )
# else:
# dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
# base_data=agent_5_base_data,
# text=research_results,
# format=agent_5_output_objective,
# )
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# # Run LLM
# Run LLM
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_formatting_prompt,
# reserved_str="",
# ),
# )
# ]
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
# try:
# _ = run_with_timeout(
# 60,
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=msg,
# event_name="initial_agent_answer",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=30,
# max_tokens=None,
# ),
# )
try:
_ = run_with_timeout(
60,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# logger.debug("DivCon Step A5 - Final Generation - completed")
logger.debug("DivCon Step A5 - Final Generation - completed")
# return ResearchUpdate(
# research_results=research_results,
# log_messages=["Agent Source Consilidation done"],
# )
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,50 +1,61 @@
# from datetime import datetime
# from typing import cast
from datetime import datetime
from typing import cast
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
# # new db session to avoid concurrency issues
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
# retrieved_docs: list[LlmDoc] = []
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
# break
# return retrieved_docs
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
# def extract_section(
# text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
# """Extract text between markers, returning None if markers not found"""
# parts = text.split(start_marker)
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
# if len(parts) == 1:
# return None
if len(parts) == 1:
return None
# after_start = parts[1].strip()
after_start = parts[1].strip()
# if not end_marker:
# return after_start
if not end_marker:
return after_start
# extract = after_start.split(end_marker)[0]
extract = after_start.split(end_marker)[0]
# return extract.strip()
return extract.strip()

View File

@@ -1,72 +1,72 @@
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# ### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
# analysis_objects: list[str] = []
# analysis_sources: list[DocumentSource] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
# class ObjectSourceInput(LoggerUpdate):
# object_source_combination: tuple[str, DocumentSource]
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceResearchUpdate(LoggerUpdate):
# object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectInformationInput(LoggerUpdate):
# object_information: Dict[str, str]
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
# class ObjectResearchInformationUpdate(LoggerUpdate):
# object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
# object_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ResearchUpdate(LoggerUpdate):
# research_results: str | None = None
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# SearchSourcesObjectsUpdate,
# ObjectSourceResearchUpdate,
# ObjectResearchInformationUpdate,
# ObjectResearchUpdate,
# ResearchUpdate,
# ):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -1,36 +1,36 @@
# from pydantic import BaseModel
from pydantic import BaseModel
# class RefinementSubQuestion(BaseModel):
# sub_question: str
# sub_question_id: str
# verified: bool
# answered: bool
# answer: str
class RefinementSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
# class AgentTimings(BaseModel):
# base_duration_s: float | None
# refined_duration_s: float | None
# full_duration_s: float | None
class AgentTimings(BaseModel):
base_duration_s: float | None
refined_duration_s: float | None
full_duration_s: float | None
# class AgentBaseMetrics(BaseModel):
# num_verified_documents_total: int | None
# num_verified_documents_core: int | None
# verified_avg_score_core: float | None
# num_verified_documents_base: int | float | None
# verified_avg_score_base: float | None = None
# base_doc_boost_factor: float | None = None
# support_boost_factor: float | None = None
# duration_s: float | None = None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
# refined_doc_boost_factor: float | None = None
# refined_question_boost_factor: float | None = None
# duration_s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration_s: float | None = None
# class AgentAdditionalMetrics(BaseModel):
# pass
class AgentAdditionalMetrics(BaseModel):
pass

View File

@@ -1,61 +1,61 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.graph import END
# from langgraph.types import Send
from langgraph.graph import END
from langgraph.types import Send
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
# if not state.tools_used:
# raise IndexError("state.tools_used cannot be empty")
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# # next_tool is either a generic tool name or a DRPath string
# next_tool_name = state.tools_used[-1]
# next_tool is either a generic tool name or a DRPath string
next_tool_name = state.tools_used[-1]
# available_tools = state.available_tools
# if not available_tools:
# raise ValueError("No tool is available. This should not happen.")
available_tools = state.available_tools
if not available_tools:
raise ValueError("No tool is available. This should not happen.")
# if next_tool_name in available_tools:
# next_tool_path = available_tools[next_tool_name].path
# elif next_tool_name == DRPath.END.value:
# return END
# elif next_tool_name == DRPath.LOGGER.value:
# return DRPath.LOGGER
# elif next_tool_name == DRPath.CLOSER.value:
# return DRPath.CLOSER
# else:
# return DRPath.ORCHESTRATOR
if next_tool_name in available_tools:
next_tool_path = available_tools[next_tool_name].path
elif next_tool_name == DRPath.END.value:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
elif next_tool_name == DRPath.CLOSER.value:
return DRPath.CLOSER
else:
return DRPath.ORCHESTRATOR
# # handle invalid paths
# if next_tool_path == DRPath.CLARIFIER:
# raise ValueError("CLARIFIER is not a valid path during iteration")
# handle invalid paths
if next_tool_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# # handle tool calls without a query
# if (
# next_tool_path
# in (
# DRPath.INTERNAL_SEARCH,
# DRPath.WEB_SEARCH,
# DRPath.KNOWLEDGE_GRAPH,
# DRPath.IMAGE_GENERATION,
# )
# and len(state.query_list) == 0
# ):
# return DRPath.CLOSER
# handle tool calls without a query
if (
next_tool_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.WEB_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
# return next_tool_path
return next_tool_path
# def completeness_router(state: MainState) -> DRPath | str:
# if not state.tools_used:
# raise IndexError("tools_used cannot be empty")
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# # go to closer if path is CLOSER or no queries
# next_path = state.tools_used[-1]
# go to closer if path is CLOSER or no queries
next_path = state.tools_used[-1]
# if next_path == DRPath.ORCHESTRATOR.value:
# return DRPath.ORCHESTRATOR
# return DRPath.LOGGER
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return DRPath.LOGGER

View File

@@ -1,27 +1,31 @@
# from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# MAX_CHAT_HISTORY_MESSAGES = (
# 3 # note: actual count is x2 to account for user and assistant messages
# )
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
# MAX_DR_PARALLEL_SEARCH = 4
MAX_DR_PARALLEL_SEARCH = 4
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
# 0 # how many times the closer can send back to the orchestrator
# )
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
# DRPath.INTERNAL_SEARCH: 1.0,
# DRPath.KNOWLEDGE_GRAPH: 2.0,
# DRPath.WEB_SEARCH: 1.5,
# DRPath.IMAGE_GENERATION: 3.0,
# DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
# DRPath.CLOSER: 0.0,
# }
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.WEB_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 12.0,
ResearchType.FAST: 0.5,
}

View File

@@ -1,111 +1,112 @@
# from datetime import datetime
from datetime import datetime
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
# def get_dr_prompt_orchestration_templates(
# purpose: DRPromptPurpose,
# use_agentic_search: bool,
# available_tools: dict[str, OrchestratorTool],
# entity_types_string: str | None = None,
# relationship_types_string: str | None = None,
# reasoning_result: str | None = None,
# tool_calls_string: str | None = None,
# ) -> PromptTemplate:
# available_tools = available_tools or {}
# tool_names = list(available_tools.keys())
# tool_description_str = "\n\n".join(
# f"- {tool_name}: {tool.description}"
# for tool_name, tool in available_tools.items()
# )
# tool_cost_str = "\n".join(
# f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
# )
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
# tool_differentiations: list[str] = [
# TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
# for tool_1 in available_tools
# for tool_2 in available_tools
# if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
# ]
# tool_differentiation_hint_string = (
# "\n".join(tool_differentiations) or "(No differentiating hints available)"
# )
# # TODO: add tool deliniation pairs for custom tools as well
tool_differentiations: list[str] = [
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
for tool_1 in available_tools
for tool_2 in available_tools
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
]
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool deliniation pairs for custom tools as well
# tool_question_hint_string = (
# "\n".join(
# "- " + TOOL_QUESTION_HINTS[tool]
# for tool in available_tools
# if tool in TOOL_QUESTION_HINTS
# )
# or "(No examples available)"
# )
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
# if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
# entity_types_string or relationship_types_string
# ):
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
entity_types_string or relationship_types_string
):
# kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
# possible_entities=entity_types_string or "",
# possible_relationships=relationship_types_string or "",
# )
# else:
# kg_types_descriptions = "(The Knowledge Graph is not used.)"
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string or "",
possible_relationships=relationship_types_string or "",
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
# if purpose == DRPromptPurpose.PLAN:
# if not use_agentic_search:
# raise ValueError("plan generation is only supported for agentic search")
# base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# else:
# raise ValueError(
# "reasoning is not separately required for agentic search"
# )
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"reasoning is not separately required for DEEP time budget"
)
# elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
# base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# else:
# base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# elif purpose == DRPromptPurpose.CLARIFICATION:
# if not use_agentic_search:
# raise ValueError("clarification is only supported for agentic search")
# base_template = GET_CLARIFICATION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
# else:
# # for mypy, clearly a mypy bug
# raise ValueError(f"Invalid purpose: {purpose}")
else:
# for mypy, clearly a mypy bug
raise ValueError(f"Invalid purpose: {purpose}")
# return base_template.partial_build(
# num_available_tools=str(len(tool_names)),
# available_tools=", ".join(tool_names),
# tool_choice_options=" or ".join(tool_names),
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# kg_types_descriptions=kg_types_descriptions,
# tool_descriptions=tool_description_str,
# tool_differentiation_hints=tool_differentiation_hint_string,
# tool_question_hints=tool_question_hint_string,
# average_tool_costs=tool_cost_str,
# reasoning_result=reasoning_result or "(No reasoning result provided.)",
# tool_calls_string=tool_calls_string or "(No tool calls provided.)",
# )
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)

View File

@@ -1,22 +1,32 @@
# from enum import Enum
from enum import Enum
# class ResearchAnswerPurpose(str, Enum):
# """Research answer purpose options for agent search operations"""
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# ANSWER = "ANSWER"
# CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
# BASIC = "BASIC"
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
FAST = "FAST"
# class DRPath(str, Enum):
# CLARIFIER = "Clarifier"
# ORCHESTRATOR = "Orchestrator"
# INTERNAL_SEARCH = "Internal Search"
# GENERIC_TOOL = "Generic Tool"
# KNOWLEDGE_GRAPH = "Knowledge Graph Search"
# WEB_SEARCH = "Web Search"
# IMAGE_GENERATION = "Image Generation"
# GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
# CLOSER = "Closer"
# LOGGER = "Logger"
# END = "End"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Internal Search"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
WEB_SEARCH = "Web Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"

View File

@@ -1,88 +1,88 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
# dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
# dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
# dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
# dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
# dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
# dr_ws_graph_builder,
# )
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# def dr_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the deep research agent.
# """
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
# graph = StateGraph(state_schema=MainState, input=MainInput)
graph = StateGraph(state_schema=MainState, input=MainInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.CLARIFIER, clarifier)
# graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
# basic_search_graph = dr_basic_search_graph_builder().compile()
# graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
# kg_search_graph = dr_kg_search_graph_builder().compile()
# graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
# internet_search_graph = dr_ws_graph_builder().compile()
# graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
internet_search_graph = dr_ws_graph_builder().compile()
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
# image_generation_graph = dr_image_generation_graph_builder().compile()
# graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
# custom_tool_graph = dr_custom_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
# generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
# graph.add_node(DRPath.CLOSER, closer)
# graph.add_node(DRPath.LOGGER, logging)
graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
# graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
# graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
# graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
# graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
# return graph
return graph

View File

@@ -1,131 +1,131 @@
# from enum import Enum
from enum import Enum
# from pydantic import BaseModel
from pydantic import BaseModel
from pydantic import ConfigDict
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
# class OrchestratorStep(BaseModel):
# tool: str
# questions: list[str]
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
# class OrchestratorDecisonsNoPlan(BaseModel):
# reasoning: str
# next_step: OrchestratorStep
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
# class OrchestrationPlan(BaseModel):
# reasoning: str
# plan: str
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
# class ClarificationGenerationResponse(BaseModel):
# clarification_needed: bool
# clarification_question: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
# class DecisionResponse(BaseModel):
# reasoning: str
# decision: str
class DecisionResponse(BaseModel):
reasoning: str
decision: str
# class QueryEvaluationResponse(BaseModel):
# reasoning: str
# query_permitted: bool
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
# class OrchestrationClarificationInfo(BaseModel):
# clarification_question: str
# clarification_response: str | None = None
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
# class WebSearchAnswer(BaseModel):
# urls_to_open_indices: list[int]
class WebSearchAnswer(BaseModel):
urls_to_open_indices: list[int]
# class SearchAnswer(BaseModel):
# reasoning: str
# answer: str
# claims: list[str] | None = None
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
# class TestInfoCompleteResponse(BaseModel):
# reasoning: str
# complete: bool
# gaps: list[str]
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# # TODO: revisit with custom tools implementation in v2
# # each tool should be a class with the attributes below, plus the actual tool implementation
# # this will also allow custom tools to have their own cost
# class OrchestratorTool(BaseModel):
# tool_id: int
# name: str
# llm_path: str # the path for the LLM to refer by
# path: DRPath # the actual path in the graph
# description: str
# metadata: dict[str, str]
# cost: float
# tool_object: Tool | None = None # None for CLOSER
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
# class Config:
# arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
# class IterationInstructions(BaseModel):
# iteration_nr: int
# plan: str | None
# reasoning: str
# purpose: str
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
# class IterationAnswer(BaseModel):
# tool: str
# tool_id: int
# iteration_nr: int
# parallelization_nr: int
# question: str
# reasoning: str | None
# answer: str
# cited_documents: dict[int, InferenceSection]
# background_info: str | None = None
# claims: list[str] | None = None
# additional_data: dict[str, str] | None = None
# response_type: str | None = None
# data: dict | list | str | int | float | bool | None = None
# file_ids: list[str] | None = None
# # TODO: This is not ideal, but we'll can rework the schema
# # for deep research later
# is_web_fetch: bool = False
# # for image generation step-types
# generated_images: list[GeneratedImage] | None = None
# # for multi-query search tools (v2 web search and internal search)
# # TODO: Clean this up to be more flexible to tools
# queries: list[str] | None = None
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we'll can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible to tools
queries: list[str] | None = None
# class AggregatedDRContext(BaseModel):
# context: str
# cited_documents: list[InferenceSection]
# is_internet_marker_dict: dict[str, bool]
# global_iteration_responses: list[IterationAnswer]
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
# class DRPromptPurpose(str, Enum):
# PLAN = "PLAN"
# NEXT_STEP = "NEXT_STEP"
# NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
# NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
# CLARIFICATION = "CLARIFICATION"
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
# class BaseSearchProcessingResponse(BaseModel):
# specified_source_types: list[str]
# rewritten_query: str
# time_filter: str
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str

File diff suppressed because it is too large Load Diff

View File

@@ -1,418 +1,423 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# logger = setup_logger()
logger = setup_logger()
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# return list(set(cited_numbers)) # Return unique numbers
return list(set(cited_numbers)) # Return unique numbers
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# return "".join(linked_citations)
return "".join(linked_citations)
# def insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
# Args:
# message_id: The ID of the chat message
# search_doc_id: The ID of the search document
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
Args:
message_id: The ID of the chat message
search_doc_id: The ID of the search document
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # then, map_search_docs to message
# insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = extract_citation_numbers(final_answer)
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = extract_citation_numbers(final_answer)
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# )
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_answer in aggregated_context.global_iteration_responses:
for iteration_answer in aggregated_context.global_iteration_responses:
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# db_session.commit()
db_session.commit()
# def closer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# current_step_nr = state.current_step_nr
current_step_nr = state.current_step_nr
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# use_agentic_search = graph_config.behavior.use_agentic_search
research_type = graph_config.behavior.research_type
# assistant_system_prompt: str = state.assistant_system_prompt or ""
# assistant_task_prompt = state.assistant_task_prompt
assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_task_prompt = state.assistant_task_prompt
# uploaded_context = state.uploaded_test_context or ""
uploaded_context = state.uploaded_test_context or ""
# clarification = state.clarification
# prompt_question = get_prompt_question(base_question, clarification)
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
# chat_history_string = (
# get_chat_history_string(
# graph_config.inputs.prompt_builder.message_history,
# MAX_CHAT_HISTORY_MESSAGES,
# )
# or "(No chat history yet available)"
# )
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
# aggregated_context_w_docs = aggregate_context(
# state.iteration_responses, include_documents=True
# )
aggregated_context_w_docs = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context_wo_docs = aggregate_context(
# state.iteration_responses, include_documents=False
# )
aggregated_context_wo_docs = aggregate_context(
state.iteration_responses, include_documents=False
)
# iteration_responses_w_docs_string = aggregated_context_w_docs.context
# iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
# all_cited_documents = aggregated_context_w_docs.cited_documents
iteration_responses_w_docs_string = aggregated_context_w_docs.context
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
all_cited_documents = aggregated_context_w_docs.cited_documents
# num_closer_suggestions = state.num_closer_suggestions
num_closer_suggestions = state.num_closer_suggestions
# if (
# num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
# and use_agentic_search
# ):
# test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
# base_question=prompt_question,
# questions_answers_claims=iteration_responses_wo_docs_string,
# chat_history_string=chat_history_string,
# high_level_plan=(
# state.plan_of_record.plan
# if state.plan_of_record
# else "No plan available"
# ),
# )
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_wo_docs_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
# test_info_complete_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# test_info_complete_prompt + (assistant_task_prompt or ""),
# ),
# schema=TestInfoCompleteResponse,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1000,
# )
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1000,
)
# if test_info_complete_json.complete:
# pass
if test_info_complete_json.complete:
pass
# else:
# return OrchestrationUpdate(
# tools_used=[DRPath.ORCHESTRATOR.value],
# query_list=[],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# gaps=test_info_complete_json.gaps,
# num_closer_suggestions=num_closer_suggestions + 1,
# )
else:
return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# all_cited_documents
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
# write_custom_event(
# current_step_nr,
# MessageStart(
# content="",
# final_documents=retrieved_search_docs,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
# if not use_agentic_search:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# else:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
elif research_type == ResearchType.DEEP:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
else:
raise ValueError(f"Invalid research type: {research_type}")
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
# final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_w_docs_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
# )
estimated_final_answer_prompt_tokens = check_number_of_tokens(
final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_w_docs_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
)
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
# # TODO: consider compression step for Thoughtful mode if context is too long.
# # Should generally not be the case though.
# for DR, rely only on sub-answers and claims to save tokens if context is too long
# TODO: consider compression step for Thoughtful mode if context is too long.
# Should generally not be the case though.
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
# if (
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
# and use_agentic_search
# ):
# iteration_responses_string = iteration_responses_wo_docs_string
# else:
# iteration_responses_string = iteration_responses_w_docs_string
if (
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
and research_type == ResearchType.DEEP
):
iteration_responses_string = iteration_responses_wo_docs_string
else:
iteration_responses_string = iteration_responses_w_docs_string
# final_answer_prompt = final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
# if graph_config.inputs.project_instructions:
# assistant_system_prompt = (
# assistant_system_prompt
# + PROJECT_INSTRUCTIONS_SEPARATOR
# + (graph_config.inputs.project_instructions or "")
# )
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
# all_context_llmdocs = [
# llm_doc_from_inference_section(inference_section)
# for inference_section in all_cited_documents
# ]
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
# try:
# streamed_output, _, citation_infos = run_with_timeout(
# int(3 * TF_DR_TIMEOUT_LONG),
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# final_answer_prompt + (assistant_task_prompt or ""),
# ),
# event_name="basic_response",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
# answer_piece=StreamingType.MESSAGE_DELTA.value,
# ind=current_step_nr,
# context_docs=all_context_llmdocs,
# replace_citations=True,
# # max_tokens=None,
# ),
# )
try:
streamed_output, _, citation_infos = run_with_timeout(
int(3 * TF_DR_TIMEOUT_LONG),
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
# final_answer = "".join(streamed_output)
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
final_answer = "".join(streamed_output)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# write_custom_event(current_step_nr, SectionEnd(), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
# current_step_nr += 1
current_step_nr += 1
# write_custom_event(current_step_nr, CitationStart(), writer)
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
# current_step_nr += 1
current_step_nr += 1
# # Log the research agent steps
# # save_iteration(
# # state,
# # graph_config,
# # aggregated_context,
# # final_answer,
# # all_cited_documents,
# # is_internet_marker_dict,
# # )
# Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# )
# return FinalUpdate(
# final_answer=final_answer,
# all_cited_documents=all_cited_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# )
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,246 +1,248 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def _extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# return list(set(cited_numbers)) # Return unique numbers
return list(set(cited_numbers)) # Return unique numbers
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# return "".join(linked_citations)
return "".join(linked_citations)
# def _insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
# Args:
# message_id: The ID of the chat message
# search_doc_id: The ID of the search document
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
Args:
message_id: The ID of the chat message
search_doc_id: The ID of the search document
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# num_tokens: int,
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
num_tokens: int,
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # then, map_search_docs to message
# _insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = _extract_citation_numbers(final_answer)
# if search_docs:
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
if search_docs:
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# token_count=num_tokens,
# )
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=num_tokens,
)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_answer in aggregated_context.global_iteration_responses:
for iteration_answer in aggregated_context.global_iteration_responses:
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# db_session.commit()
db_session.commit()
# def logging(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
def logging(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# current_step_nr = state.current_step_nr
current_step_nr = state.current_step_nr
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# aggregated_context = aggregate_context(
# state.iteration_responses, include_documents=True
# )
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
# all_cited_documents = aggregated_context.cited_documents
all_cited_documents = aggregated_context.cited_documents
# is_internet_marker_dict = aggregated_context.is_internet_marker_dict
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
# final_answer = state.final_answer or ""
# llm_provider = graph_config.tooling.primary_llm.config.model_provider
# llm_model_name = graph_config.tooling.primary_llm.config.model_name
final_answer = state.final_answer or ""
llm_provider = graph_config.tooling.primary_llm.config.model_provider
llm_model_name = graph_config.tooling.primary_llm.config.model_name
# llm_tokenizer = get_tokenizer(
# model_name=llm_model_name,
# provider_type=llm_provider,
# )
# num_tokens = len(llm_tokenizer.encode(final_answer or ""))
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
# write_custom_event(current_step_nr, OverallStop(), writer)
write_custom_event(current_step_nr, OverallStop(), writer)
# # Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# num_tokens,
# )
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
num_tokens,
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="logger",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="logger",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,131 +1,132 @@
# from collections.abc import Iterator
# from typing import cast
from collections.abc import Iterator
from typing import cast
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
# PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# class BasicSearchProcessedStreamResults(BaseModel):
# ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
# full_answer: str | None = None
# cited_references: list[InferenceSection] = []
# retrieved_documents: list[LlmDoc] = []
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []
# def process_llm_stream(
# messages: Iterator[BaseMessage],
# should_stream_answer: bool,
# writer: StreamWriter,
# ind: int,
# search_results: list[LlmDoc] | None = None,
# generate_final_answer: bool = False,
# chat_message_id: str | None = None,
# ) -> BasicSearchProcessedStreamResults:
# tool_call_chunk = AIMessageChunk(content="")
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
search_results: list[LlmDoc] | None = None,
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
# if search_results:
# answer_handler: AnswerResponseHandler = CitationResponseHandler(
# context_docs=search_results,
# doc_id_to_rank_map=map_document_id_order(search_results),
# )
# else:
# answer_handler = PassThroughAnswerResponseHandler()
if search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=search_results,
doc_id_to_rank_map=map_document_id_order(search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
# full_answer = ""
# start_final_answer_streaming_set = False
# # Accumulate citation infos if handler emits them
# collected_citation_infos: list[CitationInfo] = []
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# # the stream will contain AIMessageChunks with tool call information.
# for message in messages:
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
# answer_piece = message.content
# if not isinstance(answer_piece, str):
# # this is only used for logging, so fine to
# # just add the string representation
# answer_piece = str(answer_piece)
# full_answer += answer_piece
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
# if isinstance(message, AIMessageChunk) and (
# message.tool_call_chunks or message.tool_calls
# ):
# tool_call_chunk += message # type: ignore
# elif should_stream_answer:
# for response_part in answer_handler.handle_response_part(message):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message):
# # only stream out answer parts
# if (
# isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
# and generate_final_answer
# and response_part.answer_piece
# ):
# if chat_message_id is None:
# raise ValueError(
# "chat_message_id is required when generating final answer"
# )
# only stream out answer parts
if (
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
# if not start_final_answer_streaming_set:
# # Convert LlmDocs to SavedSearchDocs
# saved_search_docs = saved_search_docs_from_llm_docs(
# search_results
# )
# write_custom_event(
# ind,
# MessageStart(content="", final_documents=saved_search_docs),
# writer,
# )
# start_final_answer_streaming_set = True
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
# write_custom_event(
# ind,
# MessageDelta(content=response_part.answer_piece),
# writer,
# )
# # collect citation info objects
# elif isinstance(response_part, CitationInfo):
# collected_citation_infos.append(response_part)
write_custom_event(
ind,
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
# if generate_final_answer and start_final_answer_streaming_set:
# # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
# write_custom_event(
# ind,
# SectionEnd(),
# writer,
# )
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
# # Emit citations section if any were collected
# if collected_citation_infos:
# write_custom_event(ind, CitationStart(), writer)
# write_custom_event(
# ind, CitationDelta(citations=collected_citation_infos), writer
# )
# write_custom_event(ind, SectionEnd(), writer)
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
# logger.debug(f"Full answer: {full_answer}")
# return BasicSearchProcessedStreamResults(
# ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
# )
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)

View File

@@ -1,82 +1,82 @@
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from pydantic import BaseModel
from pydantic import BaseModel
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# ### States ###
### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# class OrchestrationUpdate(LoggerUpdate):
# tools_used: Annotated[list[str], add] = []
# query_list: list[str] = []
# iteration_nr: int = 0
# current_step_nr: int = 1
# plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
# remaining_time_budget: float = 2.0 # set by default to about 2 searches
# num_closer_suggestions: int = 0 # how many times the closer was suggested
# gaps: list[str] = (
# []
# ) # gaps that may be identified by the closer before being able to answer the question.
# iteration_instructions: Annotated[list[IterationInstructions], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
gaps: list[str] = (
[]
) # gaps that may be identified by the closer before being able to answer the question.
iteration_instructions: Annotated[list[IterationInstructions], add] = []
# class OrchestrationSetup(OrchestrationUpdate):
# original_question: str | None = None
# chat_history_string: str | None = None
# clarification: OrchestrationClarificationInfo | None = None
# available_tools: dict[str, OrchestratorTool] | None = None
# num_closer_suggestions: int = 0 # how many times the closer was suggested
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
# active_source_types: list[DocumentSource] | None = None
# active_source_types_descriptions: str | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
# uploaded_test_context: str | None = None
# uploaded_image_context: list[dict[str, Any]] | None = None
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
uploaded_test_context: str | None = None
uploaded_image_context: list[dict[str, Any]] | None = None
# class AnswerUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
# class FinalUpdate(LoggerUpdate):
# final_answer: str | None = None
# all_cited_documents: list[InferenceSection] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# OrchestrationSetup,
# AnswerUpdate,
# FinalUpdate,
# ):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
# ## Graph Output State
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# all_cited_documents: list[InferenceSection]
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]

View File

@@ -1,47 +1,47 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def basic_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,261 +1,286 @@
# import re
# from datetime import datetime
# from typing import cast
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
# SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def basic_search(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# current_step_nr = state.current_step_nr
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
current_step_nr = state.current_step_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
# use_agentic_search = graph_config.behavior.use_agentic_search
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
# if not state.available_tools:
# raise ValueError("available_tools is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
# elif len(state.tools_used) == 0:
# raise ValueError("tools_used is empty")
elif len(state.tools_used) == 0:
raise ValueError("tools_used is empty")
# search_tool_info = state.available_tools[state.tools_used[-1]]
# search_tool = cast(SearchTool, search_tool_info.tool_object)
# graph_config.tooling.force_use_tool
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# # sanity check
# if search_tool != graph_config.tooling.search_tool:
# raise ValueError("search_tool does not match the configured search tool")
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# # rewrite query and identify source types
# active_source_types_str = ", ".join(
# [source.value for source in state.active_source_types or []]
# )
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
# base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
# active_source_types_str=active_source_types_str,
# branch_query=branch_query,
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
# )
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
)
# try:
# search_processing = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, base_search_processing_prompt
# ),
# schema=BaseSearchProcessingResponse,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# # max_tokens=100,
# )
# except Exception as e:
# logger.error(f"Could not process query: {e}")
# raise e
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=TF_DR_TIMEOUT_SHORT,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
# rewritten_query = search_processing.rewritten_query
rewritten_query = search_processing.rewritten_query
# # give back the query so we can render it in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[rewritten_query],
# documents=[],
# ),
# writer,
# )
# give back the query so we can render it in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[rewritten_query],
documents=[],
),
writer,
)
# implied_start_date = search_processing.time_filter
implied_start_date = search_processing.time_filter
# # Validate time_filter format if it exists
# implied_time_filter = None
# if implied_start_date:
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# # Check if time_filter is in YYYY-MM-DD format
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
# if re.match(date_pattern, implied_start_date):
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
# specified_source_types: list[DocumentSource] | None = (
# strings_to_document_sources(search_processing.specified_source_types)
# if search_processing.specified_source_types
# else None
# )
specified_source_types: list[DocumentSource] | None = (
strings_to_document_sources(search_processing.specified_source_types)
if search_processing.specified_source_types
else None
)
# if specified_source_types is not None and len(specified_source_types) == 0:
# specified_source_types = None
if specified_source_types is not None and len(specified_source_types) == 0:
specified_source_types = None
# logger.debug(
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# retrieved_docs: list[InferenceSection] = []
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# for tool_response in search_tool.run(
# query=rewritten_query,
# document_sources=specified_source_types,
# time_filter=implied_time_filter,
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# break
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
# # render the retrieved docs in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# retrieved_docs, is_internet=False
# ),
# ),
# writer,
# )
break
# document_texts_list = []
# render the retrieved docs in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
retrieved_docs, is_internet=False
),
),
writer,
)
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
document_texts_list = []
# document_texts = "\n\n".join(document_texts_list)
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# logger.debug(
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
document_texts = "\n\n".join(document_texts_list)
# # Built prompt
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# if use_agentic_search:
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=branch_query,
# base_question=base_question,
# document_text=document_texts,
# )
# Built prompt
# # Run LLM
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# # search_answer_json = None
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1500,
# )
# Run LLM
# logger.debug(
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# search_answer_json = None
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1500,
)
# # get cited documents
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning
# # answer_string = ""
# # claims = []
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
# if citation_numbers and (
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
# ):
# raise ValueError("Citation numbers are out of range for retrieved docs.")
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# }
if citation_numbers and (
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
):
raise ValueError("Citation numbers are out of range for retrieved docs.")
# else:
# answer_string = ""
# claims = []
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
# }
# reasoning = ""
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=search_tool_info.llm_path,
# tool_id=search_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,77 +1,77 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# node_start_time = datetime.now()
node_start_time = datetime.now()
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
[update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# doc_list = []
doc_list = []
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(doc_list)
# retrieved_saved_search_docs = [
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
# for search_doc in search_docs
# ]
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
# for retrieved_saved_search_doc in retrieved_saved_search_docs:
# retrieved_saved_search_doc.is_internet = False
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
# basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
# basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_basic_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Web Search Sub-Agent
# """
def dr_basic_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Web Search Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", basic_search_branch)
graph.add_node("branch", basic_search_branch)
# graph.add_node("act", basic_search)
graph.add_node("act", basic_search)
# graph.add_node("reducer", is_reducer)
graph.add_node("reducer", is_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,30 +1,30 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# current_step_nr=state.current_step_nr,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
current_step_nr=state.current_step_nr,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def custom_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,164 +1,169 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def custom_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# if not state.available_tools:
# raise ValueError("available_tools is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
# custom_tool_info = state.available_tools[state.tools_used[-1]]
# custom_tool_name = custom_tool_info.name
# custom_tool = cast(CustomTool, custom_tool_info.tool_object)
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.name
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# logger.debug(
# f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=custom_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[custom_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_LONG,
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=custom_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_LONG,
)
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# # run the tool
# response_summary: CustomToolCallSummary | None = None
# for tool_response in custom_tool.run(**tool_args):
# if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
# response_summary = cast(CustomToolCallSummary, tool_response.response)
# break
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# if not response_summary:
# raise ValueError("Custom tool did not return a valid response summary")
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
# # summarise tool result
# if not response_summary.response_type:
# raise ValueError("Response type is not returned.")
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# if response_summary.response_type == "json":
# tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
# elif response_summary.response_type in {"image", "csv"}:
# tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
# else:
# tool_result_str = str(response_summary.tool_result)
# summarise tool result
if not response_summary.response_type:
raise ValueError("Response type is not returned.")
# tool_str = (
# f"Tool used: {custom_tool_name}\n"
# f"Description: {custom_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# logger.debug(
# f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get file_ids:
file_ids = None
if response_summary.response_type in {"image", "csv"} and hasattr(
response_summary.tool_result, "file_ids"
):
file_ids = response_summary.tool_result.file_ids
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=custom_tool_name,
# tool_id=custom_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type=response_summary.response_type,
# data=response_summary.tool_result,
# file_ids=file_ids,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type=response_summary.response_type,
data=response_summary.tool_result,
file_ids=file_ids,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,82 +1,82 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def custom_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
node_start_time = datetime.now()
# current_step_nr = state.current_step_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# for new_update in new_updates:
for new_update in new_updates:
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=new_update.file_ids,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=new_update.file_ids,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,28 +1,28 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
# custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
# custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
# custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_custom_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
def dr_custom_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", custom_tool_branch)
graph.add_node("branch", custom_tool_branch)
# graph.add_node("act", custom_tool_act)
graph.add_node("act", custom_tool_act)
# graph.add_node("reducer", custom_tool_reducer)
graph.add_node("reducer", custom_tool_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def generic_internal_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def generic_internal_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="generic_internal_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="generic_internal_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,147 +1,149 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def generic_internal_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def generic_internal_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# if not state.available_tools:
# raise ValueError("available_tools is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
# generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
# generic_internal_tool_name = generic_internal_tool_info.llm_path
# generic_internal_tool = generic_internal_tool_info.tool_object
generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
generic_internal_tool_name = generic_internal_tool_info.llm_path
generic_internal_tool = generic_internal_tool_info.tool_object
# if generic_internal_tool is None:
# raise ValueError("generic_internal_tool is not set")
if generic_internal_tool is None:
raise ValueError("generic_internal_tool is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# logger.debug(
# f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=generic_internal_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[generic_internal_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=generic_internal_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[generic_internal_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_SHORT,
)
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# # run the tool
# tool_responses = list(generic_internal_tool.run(**tool_args))
# final_data = generic_internal_tool.final_result(*tool_responses)
# tool_result_str = json.dumps(final_data, ensure_ascii=False)
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# tool_str = (
# f"Tool used: {generic_internal_tool.display_name}\n"
# f"Description: {generic_internal_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
# run the tool
tool_responses = list(generic_internal_tool.run(**tool_args))
final_data = generic_internal_tool.final_result(*tool_responses)
tool_result_str = json.dumps(final_data, ensure_ascii=False)
# if generic_internal_tool.display_name == "Okta Profile":
# tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
# else:
# tool_prompt = CUSTOM_TOOL_USE_PROMPT
tool_str = (
f"Tool used: {generic_internal_tool.display_name}\n"
f"Description: {generic_internal_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
if generic_internal_tool.display_name == "Okta Profile":
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
else:
tool_prompt = CUSTOM_TOOL_USE_PROMPT
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = tool_prompt.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=generic_internal_tool.name,
# tool_id=generic_internal_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type="text", # TODO: convert all response types to enums
# data=answer_string,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )
logger.debug(
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=generic_internal_tool.llm_name,
tool_id=generic_internal_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type="text", # TODO: convert all response types to enums
data=answer_string,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,82 +1,82 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def generic_internal_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
def generic_internal_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# node_start_time = datetime.now()
node_start_time = datetime.now()
# current_step_nr = state.current_step_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# for new_update in new_updates:
for new_update in new_updates:
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=[],
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,28 +1,28 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
# generic_internal_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
# generic_internal_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
# generic_internal_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_generic_internal_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
def dr_generic_internal_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", generic_internal_tool_branch)
graph.add_node("branch", generic_internal_tool_branch)
# graph.add_node("act", generic_internal_tool_act)
graph.add_node("act", generic_internal_tool_act)
# graph.add_node("reducer", generic_internal_tool_reducer)
graph.add_node("reducer", generic_internal_tool_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,45 +1,45 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def image_generation_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a image generation as part of the DR process.
# """
def image_generation_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a image generation as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# # tell frontend that we are starting the image generation tool
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolStart(),
# writer,
# )
# tell frontend that we are starting the image generation tool
write_custom_event(
state.current_step_nr,
ImageGenerationToolStart(),
writer,
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,187 +1,189 @@
# import json
# from datetime import datetime
# from typing import cast
import json
from datetime import datetime
from typing import cast
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.file_store.utils import build_frontend_file_url
# from onyx.file_store.utils import save_files
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_HEARTBEAT_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_RESPONSE_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationResponse,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationTool,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def image_generation(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def image_generation(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# state.assistant_system_prompt
# state.assistant_task_prompt
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
state.assistant_system_prompt
state.assistant_task_prompt
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.prompt_builder.raw_user_query
graph_config.behavior.research_type
# if not state.available_tools:
# raise ValueError("available_tools is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
# image_tool_info = state.available_tools[state.tools_used[-1]]
# image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
# image_prompt = branch_query
# requested_shape: ImageShape | None = None
image_prompt = branch_query
requested_shape: ImageShape | None = None
# try:
# parsed_query = json.loads(branch_query)
# except json.JSONDecodeError:
# parsed_query = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
# if isinstance(parsed_query, dict):
# prompt_from_llm = parsed_query.get("prompt")
# if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
# image_prompt = prompt_from_llm.strip()
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
# raw_shape = parsed_query.get("shape")
# if isinstance(raw_shape, str):
# try:
# requested_shape = ImageShape(raw_shape)
# except ValueError:
# logger.warning(
# "Received unsupported image shape '%s' from LLM. Falling back to square.",
# raw_shape,
# )
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
# logger.debug(
# f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # Generate images using the image generation tool
# image_generation_responses: list[ImageGenerationResponse] = []
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
# if requested_shape is not None:
# tool_iterator = image_tool.run(
# prompt=image_prompt,
# shape=requested_shape.value,
# )
# else:
# tool_iterator = image_tool.run(prompt=image_prompt)
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
# for tool_response in tool_iterator:
# if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# # Stream heartbeat to frontend
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolHeartbeat(),
# writer,
# )
# elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
# response = cast(list[ImageGenerationResponse], tool_response.response)
# image_generation_responses = response
# break
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
state.current_step_nr,
ImageGenerationToolHeartbeat(),
writer,
)
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
response = cast(list[ImageGenerationResponse], tool_response.response)
image_generation_responses = response
break
# # save images to file store
# file_ids = save_files(
# urls=[],
# base64_files=[img.image_data for img in image_generation_responses],
# )
# save images to file store
file_ids = save_files(
urls=[],
base64_files=[img.image_data for img in image_generation_responses],
)
# final_generated_images = [
# GeneratedImage(
# file_id=file_id,
# url=build_frontend_file_url(file_id),
# revised_prompt=img.revised_prompt,
# shape=(requested_shape or ImageShape.SQUARE).value,
# )
# for file_id, img in zip(file_ids, image_generation_responses)
# ]
final_generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
# logger.debug(
# f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # Create answer string describing the generated images
# if final_generated_images:
# image_descriptions = []
# for i, img in enumerate(final_generated_images, 1):
# if img.shape and img.shape != ImageShape.SQUARE.value:
# image_descriptions.append(
# f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
# )
# else:
# image_descriptions.append(f"Image {i}: {img.revised_prompt}")
# Create answer string describing the generated images
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
# answer_string = (
# f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
# + "\n".join(image_descriptions)
# )
# if requested_shape:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
# )
# else:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) based on the user's request."
# )
# else:
# answer_string = f"Failed to generate images for request: {image_prompt}"
# reasoning = "Image generation tool did not return any results."
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=image_tool_info.llm_path,
# tool_id=image_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning=reasoning,
# generated_images=final_generated_images,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="generating",
# node_start_time=node_start_time,
# )
# ],
# )
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=image_tool_info.llm_path,
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning=reasoning,
generated_images=final_generated_images,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="generating",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,71 +1,71 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# node_start_time = datetime.now()
node_start_time = datetime.now()
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
# generated_images: list[GeneratedImage] = []
# for update in new_updates:
# if update.generated_images:
# generated_images.extend(update.generated_images)
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
generated_images: list[GeneratedImage] = []
for update in new_updates:
if update.generated_images:
generated_images.extend(update.generated_images)
# # Write the results to the stream
# write_custom_event(
# current_step_nr,
# ImageGenerationFinal(
# images=generated_images,
# ),
# writer,
# )
# Write the results to the stream
write_custom_event(
current_step_nr,
ImageGenerationToolDelta(
images=generated_images,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,29 +1,29 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
# image_generation_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
# image_generation,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_image_generation_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Image Generation Sub-Agent
# """
def dr_image_generation_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Image Generation Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", image_generation_branch)
graph.add_node("branch", image_generation_branch)
# graph.add_node("act", image_generation)
graph.add_node("act", image_generation)
# graph.add_node("reducer", is_reducer)
graph.add_node("reducer", is_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,13 +1,13 @@
# from pydantic import BaseModel
from pydantic import BaseModel
# class GeneratedImage(BaseModel):
# file_id: str
# url: str
# revised_prompt: str
# shape: str | None = None
class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# # Needed for PydanticType
# class GeneratedImageFullResult(BaseModel):
# images: list[GeneratedImage]
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
images: list[GeneratedImage]

View File

@@ -1,36 +1,36 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def kg_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,97 +1,97 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def kg_search(
# state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> BranchUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# state.current_step_nr
# parallelization_nr = state.parallelization_nr
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
state.current_step_nr
parallelization_nr = state.parallelization_nr
# search_query = state.branch_question
# if not search_query:
# raise ValueError("search_query is not set")
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
# logger.debug(
# f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# if not state.available_tools:
# raise ValueError("available_tools is not set")
if not state.available_tools:
raise ValueError("available_tools is not set")
# kg_tool_info = state.available_tools[state.tools_used[-1]]
kg_tool_info = state.available_tools[state.tools_used[-1]]
# kb_graph = kb_graph_builder().compile()
kb_graph = kb_graph_builder().compile()
# kb_results = kb_graph.invoke(
# input=KbMainInput(question=search_query, individual_flow=False),
# config=config,
# )
kb_results = kb_graph.invoke(
input=KbMainInput(question=search_query, individual_flow=False),
config=config,
)
# # get cited documents
# answer_string = kb_results.get("final_answer") or "No answer provided"
# claims: list[str] = []
# retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
# get cited documents
answer_string = kb_results.get("final_answer") or "No answer provided"
claims: list[str] = []
retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# # if citation is empty, the answer must have come from the KG rather than a doc
# # in that case, simply cite the docs returned by the KG
# if not citation_numbers:
# citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
# if citation is empty, the answer must have come from the KG rather than a doc
# in that case, simply cite the docs returned by the KG
if not citation_numbers:
citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# if citation_number <= len(retrieved_docs)
# }
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
if citation_number <= len(retrieved_docs)
}
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=kg_tool_info.llm_path,
# tool_id=kg_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=search_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=None,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=kg_tool_info.llm_path,
tool_id=kg_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=None,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,121 +1,121 @@
# from datetime import datetime
from datetime import datetime
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningStart
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# _MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
# def kg_search_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
def kg_search_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# node_start_time = datetime.now()
node_start_time = datetime.now()
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# queries = [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# doc_list = []
doc_list = []
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
# kg_answer = (
# "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
# if len(queries) == 1
# else None
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
kg_answer = (
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
if len(queries) == 1
else None
)
# if len(retrieved_search_docs) > 0:
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=queries,
# documents=retrieved_search_docs,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
if len(retrieved_search_docs) > 0:
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# if kg_answer is not None:
if kg_answer is not None:
# kg_display_answer = (
# f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
# if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
# else kg_answer
# )
kg_display_answer = (
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
else kg_answer
)
# write_custom_event(
# current_step_nr,
# ReasoningStart(),
# writer,
# )
# write_custom_event(
# current_step_nr,
# ReasoningDelta(reasoning=kg_display_answer),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(reasoning=kg_display_answer),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# current_step_nr += 1
current_step_nr += 1
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,27 +1,27 @@
# from collections.abc import Hashable
from collections.abc import Hashable
# from langgraph.types import Send
from langgraph.types import Send
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]

View File

@@ -1,50 +1,50 @@
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# logger = setup_logger()
logger = setup_logger()
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# ### Add nodes ###
### Add nodes ###
# graph.add_node("branch", kg_search_branch)
graph.add_node("branch", kg_search_branch)
# graph.add_node("act", kg_search)
graph.add_node("act", kg_search)
# graph.add_node("reducer", kg_search_reducer)
graph.add_node("reducer", kg_search_reducer)
# ### Add edges ###
### Add edges ###
# graph.add_edge(start_key=START, end_key="branch")
graph.add_edge(start_key=START, end_key="branch")
# graph.add_conditional_edges("branch", branching_router)
graph.add_conditional_edges("branch", branching_router)
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="reducer", end_key=END)
graph.add_edge(start_key="reducer", end_key=END)
# return graph
return graph

View File

@@ -1,46 +1,46 @@
# from operator import add
# from typing import Annotated
from operator import add
from typing import Annotated
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool

Some files were not shown because too many files have changed in this diff Show More