Compare commits

...

48 Commits

Author SHA1 Message Date
Yuhong Sun
c8dbd8c7d1 project citations still not working 2025-11-28 23:22:35 -08:00
Yuhong Sun
aa37f7b162 ok 2025-11-28 22:39:38 -08:00
Yuhong Sun
a121c2cd29 k 2025-11-28 20:15:03 -08:00
Yuhong Sun
ba38585dce deleted some tests 2025-11-28 19:57:39 -08:00
Yuhong Sun
afa574dd98 ok 2025-11-28 19:51:10 -08:00
Yuhong Sun
5459b1253e k 2025-11-28 19:21:34 -08:00
Vega
1e38af1dd2 Stop Signal Handling (#6474) 2025-11-28 10:07:54 -08:00
Yuhong Sun
862b2c30bb k 2025-11-28 09:47:32 -08:00
Yuhong Sun
1720e49253 k 2025-11-27 16:51:28 -08:00
Yuhong Sun
72d1292fb8 fix import 2025-11-27 16:48:45 -08:00
Yuhong Sun
9092e09230 ok 2025-11-27 16:47:59 -08:00
Yuhong Sun
f90204f399 checkpoint 2025-11-27 13:32:38 -08:00
Yuhong Sun
9d49974249 small 2025-11-26 17:52:15 -08:00
Yuhong Sun
af58c1bdb9 small 2025-11-26 17:25:27 -08:00
Yuhong Sun
c71f154ed1 minor 2025-11-26 17:13:05 -08:00
Yuhong Sun
f6eea7c619 Basic history works 2025-11-26 15:08:58 -08:00
Yuhong Sun
ffc3dc2dc0 reminders 2025-11-26 11:22:59 -08:00
Yuhong Sun
6afde1b9b6 Image Gen works 2025-11-25 19:22:33 -08:00
Yuhong Sun
5445a1c0bd Minor fixes 2025-11-25 14:03:06 -08:00
Yuhong Sun
8f62f92dfb fix reasoning 2025-11-25 13:23:28 -08:00
Yuhong Sun
b6c3b18031 fixed replay 2025-11-25 12:00:25 -08:00
Yuhong Sun
4042a15cbc open url tool 2025-11-25 10:09:59 -08:00
Yuhong Sun
5480e89076 Web Search works again 2025-11-24 16:31:09 -08:00
Yuhong Sun
7534bad959 Minimal Replay Working 2025-11-24 15:00:49 -08:00
vega
5cc4f593d2 Fix citation processor 2025-11-23 22:17:50 -08:00
Yuhong Sun
75b6110cd9 working on loading next 2025-11-23 20:34:54 -08:00
Yuhong Sun
35a93ebf75 message saving updated 2025-11-23 17:58:23 -08:00
Yuhong Sun
19ea85f3d9 get answer from citation processor 2025-11-23 15:04:28 -08:00
Yuhong Sun
37ac066a9d Add save chat 2025-11-23 13:25:56 -08:00
Yuhong Sun
3ecd3422f0 citations not working 2025-11-23 13:21:07 -08:00
Yuhong Sun
961909b401 fixed-breaking-issues 2025-11-22 17:13:14 -08:00
Yuhong Sun
44c519f5e8 Nov22updates 2025-11-22 17:04:56 -08:00
Yuhong Sun
1be49fb51c merge-main 2025-11-20 18:55:05 -08:00
Yuhong Sun
e429ebc1be Search Tool Insides (#6348) 2025-11-19 16:32:44 -08:00
Yuhong Sun
f8dbe4f307 new message format 2025-11-15 15:25:04 -08:00
Yuhong Sun
8dd0e0fa9a merge-main 2025-11-14 16:14:24 -08:00
Yuhong Sun
3394637cad Tools new interface (#6233) 2025-11-14 14:11:06 -08:00
Yuhong Sun
abc53d609b Saving the turn (#6225) 2025-11-13 16:20:17 -08:00
Yuhong Sun
65580d1104 Fix web search (#6218) 2025-11-13 12:56:08 -08:00
Yuhong Sun
6bb2755153 Tool cleanup (#6217) 2025-11-13 12:01:31 -08:00
Yuhong Sun
c1f1731c55 Small interface update 2025-11-13 11:51:31 -08:00
Yuhong Sun
75c0b0d6a1 Tools cleanup (#6214) 2025-11-13 11:12:36 -08:00
Yuhong Sun
06371b9383 Interfaces (#6210) 2025-11-12 19:48:53 -08:00
Yuhong Sun
978f5e2618 Update README.md 2025-11-12 19:02:40 -08:00
Yuhong Sun
60d7064fd9 Update README.md 2025-11-12 18:10:02 -08:00
Yuhong Sun
424a716d08 Update README.md 2025-11-12 17:21:40 -08:00
Yuhong Sun
64718cd066 Add explanation for context (#6207) 2025-11-12 17:19:11 -08:00
Yuhong Sun
cd0f58808a Feat branch/backend refactor (#6205) 2025-11-12 16:12:02 -08:00
540 changed files with 47905 additions and 27288 deletions

View File

@@ -16,6 +16,7 @@ self-hosted-runner:
- runner=16cpu-linux-arm64
- runner=16cpu-linux-x64
- ubuntu-slim # Currently in public preview
- volume=40gb
# Configuration variables in array of strings defined in your repository or
# organization. `null` means disabling configuration variables check.

View File

@@ -7,9 +7,9 @@ runs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
restore-keys: |
${{ runner.os }}-playwright-
${{ runner.os }}-${{ runner.arch }}-playwright-
- name: Install playwright
shell: bash

101
.github/actions/slack-notify/action.yml vendored Normal file
View File

@@ -0,0 +1,101 @@
name: "Slack Notify on Failure"
description: "Sends a Slack notification when a workflow fails"
inputs:
webhook-url:
description: "Slack webhook URL (can also use SLACK_WEBHOOK_URL env var)"
required: false
failed-jobs:
description: "List of failed job names (newline-separated)"
required: false
title:
description: "Title for the notification"
required: false
default: "🚨 Workflow Failed"
ref-name:
description: "Git ref name (tag/branch)"
required: false
runs:
using: "composite"
steps:
- name: Send Slack notification
shell: bash
env:
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
run: |
if [ -z "$SLACK_WEBHOOK_URL" ]; then
echo "webhook-url input or SLACK_WEBHOOK_URL env var is not set, skipping notification"
exit 0
fi
# Get inputs with defaults
FAILED_JOBS="${{ inputs.failed-jobs }}"
TITLE="${{ inputs.title }}"
REF_NAME="${{ inputs.ref-name }}"
REPO="${{ github.repository }}"
WORKFLOW="${{ github.workflow }}"
RUN_NUMBER="${{ github.run_number }}"
RUN_ID="${{ github.run_id }}"
SERVER_URL="${{ github.server_url }}"
WORKFLOW_URL="${SERVER_URL}/${REPO}/actions/runs/${RUN_ID}"
# Use ref_name from input or fall back to github.ref_name
if [ -z "$REF_NAME" ]; then
REF_NAME="${{ github.ref_name }}"
fi
# Escape JSON special characters
escape_json() {
local input="$1"
# Escape backslashes first (but preserve \n sequences)
# Protect \n sequences temporarily
input=$(printf '%s' "$input" | sed 's/\\n/\x01NL\x01/g')
# Escape remaining backslashes
input=$(printf '%s' "$input" | sed 's/\\/\\\\/g')
# Restore \n sequences (single backslash, will be correct in JSON)
input=$(printf '%s' "$input" | sed 's/\x01NL\x01/\\n/g')
# Escape quotes
printf '%s' "$input" | sed 's/"/\\"/g'
}
REF_NAME_ESC=$(escape_json "$REF_NAME")
FAILED_JOBS_ESC=$(escape_json "$FAILED_JOBS")
WORKFLOW_URL_ESC=$(escape_json "$WORKFLOW_URL")
TITLE_ESC=$(escape_json "$TITLE")
# Build JSON payload piece by piece
# Note: FAILED_JOBS_ESC already contains \n sequences that should remain as \n in JSON
PAYLOAD="{"
PAYLOAD="${PAYLOAD}\"text\":\"${TITLE_ESC}\","
PAYLOAD="${PAYLOAD}\"blocks\":[{"
PAYLOAD="${PAYLOAD}\"type\":\"header\","
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"plain_text\",\"text\":\"${TITLE_ESC}\"}"
PAYLOAD="${PAYLOAD}},{"
PAYLOAD="${PAYLOAD}\"type\":\"section\","
PAYLOAD="${PAYLOAD}\"fields\":["
if [ -n "$REF_NAME" ]; then
PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Ref:*\\n${REF_NAME_ESC}\"},"
fi
PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Run ID:*\\n#${RUN_NUMBER}\"}"
PAYLOAD="${PAYLOAD}]"
PAYLOAD="${PAYLOAD}}"
if [ -n "$FAILED_JOBS" ]; then
PAYLOAD="${PAYLOAD},{"
PAYLOAD="${PAYLOAD}\"type\":\"section\","
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"*Failed Jobs:*\\n${FAILED_JOBS_ESC}\"}"
PAYLOAD="${PAYLOAD}}"
fi
PAYLOAD="${PAYLOAD},{"
PAYLOAD="${PAYLOAD}\"type\":\"actions\","
PAYLOAD="${PAYLOAD}\"elements\":[{"
PAYLOAD="${PAYLOAD}\"type\":\"button\","
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"plain_text\",\"text\":\"View Workflow Run\"},"
PAYLOAD="${PAYLOAD}\"url\":\"${WORKFLOW_URL_ESC}\""
PAYLOAD="${PAYLOAD}}]"
PAYLOAD="${PAYLOAD}}"
PAYLOAD="${PAYLOAD}]"
PAYLOAD="${PAYLOAD}}"
curl -X POST -H 'Content-type: application/json' \
--data "$PAYLOAD" \
"$SLACK_WEBHOOK_URL"

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
@@ -17,9 +20,11 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # ratchet:actions/setup-python@v4
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: '3.11'

File diff suppressed because it is too large Load Diff

View File

@@ -10,6 +10,9 @@ on:
description: "The version (ie v1.0.0-beta.0) to tag as beta"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -20,7 +23,7 @@ jobs:
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@dd4fa0671be5250ee6f50aedf4cb05514abda2c7 # ratchet:docker/login-action@v1
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -29,13 +32,19 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}

View File

@@ -10,6 +10,9 @@ on:
description: "The version (ie v0.0.1) to tag as latest"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -20,7 +23,7 @@ jobs:
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@dd4fa0671be5250ee6f50aedf4cb05514abda2c7 # ratchet:docker/login-action@v1
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -29,13 +32,19 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}

View File

@@ -17,6 +17,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install Helm CLI
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4

View File

@@ -15,19 +15,24 @@ on:
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
@@ -54,7 +59,9 @@ jobs:
- name: Print report
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web

View File

@@ -8,6 +8,9 @@ on:
pull_request:
branches: [main]
permissions:
contents: read
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
@@ -37,6 +40,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -67,6 +72,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -97,10 +104,12 @@ jobs:
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
env:
TEST_DIR: ${{ matrix.test-dir }}
run: |
py.test \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/external_dependency_unit/${{ matrix.test-dir }}
backend/tests/external_dependency_unit/${TEST_DIR}

View File

@@ -9,6 +9,9 @@ on:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
permissions:
contents: read
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
@@ -20,6 +23,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Set up Helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
@@ -32,9 +36,11 @@ jobs:
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "default_branch: ${DEFAULT_BRANCH}"
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
@@ -54,7 +60,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # ratchet:helm/kind-action@v1.12.0
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
- name: Pre-install cluster status check
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -10,6 +10,9 @@ on:
- main
- "release/**"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -37,6 +40,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -65,6 +70,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -85,8 +92,11 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
@@ -96,6 +106,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -116,8 +128,10 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
@@ -126,21 +140,33 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Prepare build
uses: ./.github/actions/prepare-build
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: cd backend && docker buildx bake --push integration
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests:
needs:
@@ -165,6 +191,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -175,27 +203,12 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Docker images
run: |
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# Re-tag to remove registry prefix for docker-compose
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} onyxdotapp/onyx-integration:test
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -204,7 +217,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
@@ -281,6 +295,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
@@ -297,7 +312,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
@@ -336,6 +351,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -343,17 +360,10 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Docker images
run: |
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }}) &
wait
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} onyxdotapp/onyx-integration:test
- name: Start Docker containers for multi-tenant tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -361,7 +371,8 @@ jobs:
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
@@ -402,6 +413,9 @@ jobs:
echo "Finished waiting for service."
- name: Run Multi-Tenant Integration Tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx_default \
@@ -417,6 +431,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
@@ -424,9 +439,8 @@ jobs:
-e SKIP_RESET=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
${ECR_CACHE}:integration-test-${RUN_ID} \
/app/tests/integration/multitenant_tests
- name: Dump API server logs (multi-tenant)
@@ -460,13 +474,6 @@ jobs:
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -5,6 +5,9 @@ concurrency:
on: push
permissions:
contents: read
jobs:
jest-tests:
name: Jest Tests
@@ -12,6 +15,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4

View File

@@ -1,7 +1,7 @@
name: PR Labeler
on:
pull_request_target:
pull_request:
branches:
- main
types:
@@ -12,7 +12,6 @@ on:
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:

View File

@@ -7,6 +7,9 @@ on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
linear-check:
runs-on: ubuntu-latest

View File

@@ -7,6 +7,9 @@ on:
merge_group:
types: [checks_requested]
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -33,6 +36,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -60,6 +65,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -80,8 +87,10 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -90,6 +99,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -110,8 +121,10 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
@@ -119,21 +132,33 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Prepare build
uses: ./.github/actions/prepare-build
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: cd backend && docker buildx bake --push integration
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
@@ -158,6 +183,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -168,27 +195,12 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Docker images
run: |
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from registry
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# Re-tag to remove registry prefix for docker-compose
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} onyxdotapp/onyx-integration:test
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
@@ -196,7 +208,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
@@ -289,7 +302,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
@@ -321,13 +334,6 @@ jobs:
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -5,6 +5,9 @@ concurrency:
on: push
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -16,7 +19,23 @@ env:
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }}
# for MCP Oauth tests
MCP_OAUTH_CLIENT_ID: ${{ secrets.MCP_OAUTH_CLIENT_ID }}
MCP_OAUTH_CLIENT_SECRET: ${{ secrets.MCP_OAUTH_CLIENT_SECRET }}
MCP_OAUTH_ISSUER: ${{ secrets.MCP_OAUTH_ISSUER }}
MCP_OAUTH_JWKS_URI: ${{ secrets.MCP_OAUTH_JWKS_URI }}
MCP_OAUTH_USERNAME: ${{ vars.MCP_OAUTH_USERNAME }}
MCP_OAUTH_PASSWORD: ${{ secrets.MCP_OAUTH_PASSWORD }}
MOCK_LLM_RESPONSE: true
MCP_TEST_SERVER_PORT: 8004
MCP_TEST_SERVER_URL: http://host.docker.internal:8004/mcp
MCP_TEST_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
MCP_TEST_SERVER_BIND_HOST: 0.0.0.0
MCP_TEST_SERVER_PUBLIC_HOST: host.docker.internal
MCP_SERVER_HOST: 0.0.0.0
MCP_SERVER_PUBLIC_HOST: host.docker.internal
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
jobs:
build-web-image:
@@ -26,6 +45,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -46,8 +67,10 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
type=registry,ref=onyxdotapp/onyx-web-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-image:
@@ -57,6 +80,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -77,8 +102,11 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -88,6 +116,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -108,8 +138,10 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
playwright-tests:
@@ -127,23 +159,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
- name: Pull Docker images
run: |
# Pull all images from ECR in parallel
echo "Pulling Docker images in parallel..."
(docker pull ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}) &
(docker pull ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}) &
(docker pull ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
echo "All Docker images pulled successfully"
# Re-tag with expected names for docker-compose
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }} onyxdotapp/onyx-web-server:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }} onyxdotapp/onyx-model-server:test
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
@@ -169,15 +185,22 @@ jobs:
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
EXA_API_KEY=${{ env.EXA_API_KEY }}
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
IMAGE_TAG=test
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# needed for pulling Vespa, Redis, Postgres, and Minio images
@@ -192,7 +215,7 @@ jobs:
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml up -d
id: start_docker
- name: Wait for service to be ready
@@ -229,12 +252,37 @@ jobs:
done
echo "Finished waiting for service."
- name: Wait for MCP OAuth mock server
run: |
echo "Waiting for MCP OAuth mock server on port ${MCP_TEST_SERVER_PORT:-8004}..."
start_time=$(date +%s)
timeout=120
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. MCP OAuth mock server did not become ready in ${timeout}s."
exit 1
fi
if curl -sf "http://localhost:${MCP_TEST_SERVER_PORT:-8004}/healthz" > /dev/null; then
echo "MCP OAuth mock server is ready!"
break
fi
sleep 3
done
- name: Run Playwright tests
working-directory: ./web
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${{ matrix.project }}
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
@@ -247,10 +295,12 @@ jobs:
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
@@ -259,6 +309,17 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
@@ -21,6 +24,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
@@ -48,6 +53,9 @@ jobs:
- name: Run MyPy
working-directory: ./backend
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy .
- name: Check import order with reorder-python-imports

View File

@@ -11,6 +11,9 @@ on:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
permissions:
contents: read
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
@@ -132,6 +135,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -214,8 +219,10 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
--data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
$SLACK_WEBHOOK

View File

@@ -11,6 +11,9 @@ on:
required: false
default: 'main'
permissions:
contents: read
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -36,6 +39,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -55,7 +60,7 @@ jobs:
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: "3.11"
cache: "pip"
@@ -122,10 +127,12 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
backend-check:
# See https://runs-on.com/runners/linux/
@@ -28,6 +31,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies

View File

@@ -7,6 +7,9 @@ on:
merge_group:
pull_request: null
permissions:
contents: read
jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
@@ -16,7 +19,8 @@ jobs:
- uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
persist-credentials: false
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: "3.11"
- name: Setup Terraform

View File

@@ -16,6 +16,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install git-filter-repo
run: |

View File

@@ -3,30 +3,29 @@ name: Nightly Tag Push
on:
schedule:
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
workflow_dispatch:
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-create-and-push-tag"]
runs-on: ubuntu-slim
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
# and not rkuo's personal account. It is fine to leave this key as is!
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@onyx.app"
git config user.name "Onyx Bot [bot]"
git config user.email "onyx-bot[bot]@onyx.app"
- name: Check for existing nightly tag
id: check_tag
@@ -54,3 +53,12 @@ jobs:
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME
- name: Send Slack notification
if: failure()
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
title: "🚨 Nightly Tag Push Failed"
ref-name: ${{ github.ref_name }}
failed-jobs: "create-and-push-tag"

35
.github/workflows/zizmor.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Run Zizmor
on:
push:
branches: ["main"]
pull_request:
branches: ["**"]
permissions: {}
jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
permissions:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # ratchet:actions/checkout@v5.0.1
with:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3
- name: Run zizmor
run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif
category: zizmor

2
.gitignore vendored
View File

@@ -46,5 +46,7 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl
node_modules
# MCP configs
.playwright-mcp

View File

@@ -65,15 +65,12 @@ repos:
language: system
pass_filenames: false
files: \.tf$
- id: check-lazy-imports
name: Check lazy imports are not directly imported
name: Check lazy imports
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/(?!\.venv/).*\.py$
pass_filenames: false
# Note: pass_filenames is false because tsc must check the entire
# project, but the files filter ensures this only runs when relevant
# files change. Using --incremental for faster subsequent checks.
# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install

View File

@@ -21,6 +21,11 @@
</a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/12516" target="_blank">
<img src="https://trendshift.io/api/badge/repositories/12516" alt="onyx-dot-app/onyx | Trendshift" style="width: 250px; height: 55px;" />
</a>
</p>
**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.

View File

@@ -7,15 +7,12 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/
# Install system dependencies
# cmake needed for psycopg (postgres)
@@ -128,6 +125,10 @@ COPY --chown=onyx:onyx ./assets /app/assets
ENV PYTHONPATH=/app
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]

View File

@@ -6,13 +6,10 @@ AI models for Onyx. This container and all the code is MIT Licensed and free for
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
ENV DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/
# Create non-root user for security best practices
RUN mkdir -p /app && \
@@ -23,24 +20,6 @@ RUN mkdir -p /app && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# --- add toolchain needed for Rust/Python builds (fastuuid) ---
ENV RUSTUP_HOME=/usr/local/rustup \
CARGO_HOME=/usr/local/cargo \
PATH=/usr/local/cargo/bin:$PATH
RUN set -eux; \
apt-get update && apt-get install -y --no-install-recommends \
build-essential \
pkg-config \
curl \
ca-certificates \
# Install latest stable Rust (supports Cargo.lock v4)
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
&& rustc --version && cargo --version \
&& apt-get remove -y --allow-remove-essential perl-base \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
-r /tmp/requirements.txt && \
@@ -83,4 +62,8 @@ COPY ./model_server /app/model_server
ENV PYTHONPATH=/app
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -0,0 +1,89 @@
"""add internet search and content provider tables
Revision ID: 1f2a3b4c5d6e
Revises: 9drpiiw74ljy
Create Date: 2025-11-10 19:45:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1f2a3b4c5d6e"
down_revision = "9drpiiw74ljy"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"internet_search_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_search_provider_is_active",
"internet_search_provider",
["is_active"],
)
op.create_table(
"internet_content_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_content_provider_is_active",
"internet_content_provider",
["is_active"],
)
def downgrade() -> None:
op.drop_index(
"ix_internet_content_provider_is_active", table_name="internet_content_provider"
)
op.drop_table("internet_content_provider")
op.drop_index(
"ix_internet_search_provider_is_active", table_name="internet_search_provider"
)
op.drop_table("internet_search_provider")

View File

@@ -0,0 +1,89 @@
"""seed_exa_provider_from_env
Revision ID: 3c9a65f1207f
Revises: 1f2a3b4c5d6e
Create Date: 2025-11-20 19:18:00.000000
"""
from __future__ import annotations
import os
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from dotenv import load_dotenv, find_dotenv
from onyx.utils.encryption import encrypt_string_to_bytes
revision = "3c9a65f1207f"
down_revision = "1f2a3b4c5d6e"
branch_labels = None
depends_on = None
EXA_PROVIDER_NAME = "Exa"
def _get_internet_search_table(metadata: sa.MetaData) -> sa.Table:
return sa.Table(
"internet_search_provider",
metadata,
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("name", sa.String),
sa.Column("provider_type", sa.String),
sa.Column("api_key", sa.LargeBinary),
sa.Column("config", postgresql.JSONB),
sa.Column("is_active", sa.Boolean),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
def upgrade() -> None:
load_dotenv(find_dotenv())
exa_api_key = os.environ.get("EXA_API_KEY")
if not exa_api_key:
return
bind = op.get_bind()
metadata = sa.MetaData()
table = _get_internet_search_table(metadata)
existing = bind.execute(
sa.select(table.c.id).where(table.c.name == EXA_PROVIDER_NAME)
).first()
if existing:
return
encrypted_key = encrypt_string_to_bytes(exa_api_key)
has_active_provider = bind.execute(
sa.select(table.c.id).where(table.c.is_active.is_(True))
).first()
bind.execute(
table.insert().values(
name=EXA_PROVIDER_NAME,
provider_type="exa",
api_key=encrypted_key,
config=None,
is_active=not bool(has_active_provider),
)
)
def downgrade() -> None:
return

View File

@@ -0,0 +1,104 @@
"""add_open_url_tool
Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None
OPEN_URL_TOOL = {
"name": "OpenURLTool",
"display_name": "Open URL",
"description": (
"The Open URL Action allows the agent to fetch and read contents of web pages."
),
"in_code_tool_id": "OpenURLTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
tool_id = existing[0]
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
OPEN_URL_TOOL,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
OPEN_URL_TOOL,
)
# Get the newly inserted tool's id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
tool_id = result[0] # type: ignore
# Associate the tool with all existing personas
# Get all persona IDs
persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()
for (persona_id,) in persona_ids:
# Check if association already exists
exists = conn.execute(
sa.text(
"""
SELECT 1 FROM persona__tool
WHERE persona_id = :persona_id AND tool_id = :tool_id
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
).fetchone()
if not exists:
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (:persona_id, :tool_id)
"""
),
{"persona_id": persona_id, "tool_id": tool_id},
)
def downgrade() -> None:
# We don't remove the tool on downgrade since it's fine to have it around.
# If we upgrade again, it will be a no-op.
pass

View File

@@ -0,0 +1,97 @@
"""add config to federated_connector
Revision ID: 9drpiiw74ljy
Revises: 2acdef638fc2
Create Date: 2025-11-03 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "9drpiiw74ljy"
down_revision = "2acdef638fc2"
branch_labels = None
depends_on = None
def upgrade() -> None:
connection = op.get_bind()
# Check if column already exists in current schema
result = connection.execute(
sa.text(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_schema = current_schema()
AND table_name = 'federated_connector'
AND column_name = 'config'
"""
)
)
column_exists = result.fetchone() is not None
# Add config column with default empty object (only if it doesn't exist)
if not column_exists:
op.add_column(
"federated_connector",
sa.Column(
"config", postgresql.JSONB(), nullable=False, server_default="{}"
),
)
# Data migration: Single bulk update for all Slack connectors
connection.execute(
sa.text(
"""
WITH connector_configs AS (
SELECT
fc.id as connector_id,
CASE
WHEN fcds.entities->'channels' IS NOT NULL
AND jsonb_typeof(fcds.entities->'channels') = 'array'
AND jsonb_array_length(fcds.entities->'channels') > 0
THEN
jsonb_build_object(
'channels', fcds.entities->'channels',
'search_all_channels', false
) ||
CASE
WHEN fcds.entities->'include_dm' IS NOT NULL
THEN jsonb_build_object('include_dm', fcds.entities->'include_dm')
ELSE '{}'::jsonb
END
ELSE
jsonb_build_object('search_all_channels', true) ||
CASE
WHEN fcds.entities->'include_dm' IS NOT NULL
THEN jsonb_build_object('include_dm', fcds.entities->'include_dm')
ELSE '{}'::jsonb
END
END as config
FROM federated_connector fc
LEFT JOIN LATERAL (
SELECT entities
FROM federated_connector__document_set
WHERE federated_connector_id = fc.id
AND entities IS NOT NULL
ORDER BY id
LIMIT 1
) fcds ON true
WHERE fc.source = 'FEDERATED_SLACK'
AND fcds.entities IS NOT NULL
)
UPDATE federated_connector fc
SET config = cc.config
FROM connector_configs cc
WHERE fc.id = cc.connector_id
"""
)
)
def downgrade() -> None:
op.drop_column("federated_connector", "config")

View File

@@ -0,0 +1,572 @@
"""New Chat History
Revision ID: a852cbe15577
Revises: 3c9a65f1207f
Create Date: 2025-11-08 15:16:37.781308
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "3c9a65f1207f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop research agent tables (if they exist)
op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
# Drop agent sub query and sub question tables (if they exist)
op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")
# Update ChatMessage table
# Rename parent_message to parent_message_id and make it a foreign key (if not already done)
conn = op.get_bind()
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'parent_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message", "parent_message", new_column_name="parent_message_id"
)
op.create_foreign_key(
"fk_chat_message_parent_message_id",
"chat_message",
"chat_message",
["parent_message_id"],
["id"],
)
# Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
"""
)
)
if result.fetchone():
op.alter_column(
"chat_message",
"latest_child_message",
new_column_name="latest_child_message_id",
)
op.create_foreign_key(
"fk_chat_message_latest_child_message_id",
"chat_message",
"chat_message",
["latest_child_message_id"],
["id"],
)
# Add reasoning_tokens column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Drop columns no longer needed (if they exist)
for col in [
"rephrased_query",
"alternate_assistant_id",
"overridden_model",
"is_agentic",
"refined_answer_improvement",
"research_type",
"research_plan",
"research_answer_purpose",
]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'chat_message' AND column_name = '{col}'
"""
)
)
if result.fetchone():
op.drop_column("chat_message", col)
# Update ToolCall table
# Add chat_session_id column (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
)
op.create_foreign_key(
"fk_tool_call_chat_session_id",
"tool_call",
"chat_session",
["chat_session_id"],
["id"],
)
# Rename message_id to parent_chat_message_id and make nullable (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'message_id'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call",
"message_id",
new_column_name="parent_chat_message_id",
nullable=True,
)
# Add parent_tool_call_id (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_tool_call_parent_tool_call_id",
"tool_call",
"tool_call",
["parent_tool_call_id"],
["id"],
)
op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
# Add turn_number, tool_id (if not exists)
for col_name in ["turn_number", "tool_id"]:
result = conn.execute(
sa.text(
f"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = '{col_name}'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
)
# Add tool_call_id as String (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
)
# Add reasoning_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
)
# Rename tool_arguments to tool_call_arguments (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
"""
)
)
if result.fetchone():
op.alter_column(
"tool_call", "tool_arguments", new_column_name="tool_call_arguments"
)
# Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
result = conn.execute(
sa.text(
"""
SELECT column_name, data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_result'
"""
)
)
tool_result_row = result.fetchone()
if tool_result_row:
op.alter_column(
"tool_call", "tool_result", new_column_name="tool_call_response"
)
# Change type from JSONB to Text
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
else:
# Check if tool_call_response already exists and is JSONB, then convert to Text
result = conn.execute(
sa.text(
"""
SELECT data_type FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
"""
)
)
tool_call_response_row = result.fetchone()
if tool_call_response_row and tool_call_response_row[0] == "jsonb":
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE TEXT
USING tool_call_response::text
"""
)
)
# Add tool_call_tokens (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column(
"tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
),
)
# Add generated_images column for image generation tool replay (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'generated_images'
"""
)
)
if not result.fetchone():
op.add_column(
"tool_call",
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
)
# Drop tool_name column (if exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'tool_call' AND column_name = 'tool_name'
"""
)
)
if result.fetchone():
op.drop_column("tool_call", "tool_name")
# Create tool_call__search_doc association table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT table_name FROM information_schema.tables
WHERE table_name = 'tool_call__search_doc'
"""
)
)
if not result.fetchone():
op.create_table(
"tool_call__search_doc",
sa.Column("tool_call_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
)
# Add replace_base_system_prompt to persona table (if not exists)
result = conn.execute(
sa.text(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
"""
)
)
if not result.fetchone():
op.add_column(
"persona",
sa.Column(
"replace_base_system_prompt",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
# Reverse persona changes
op.drop_column("persona", "replace_base_system_prompt")
# Drop tool_call__search_doc association table
op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")
# Reverse ToolCall changes
op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
op.drop_column("tool_call", "tool_id")
op.drop_column("tool_call", "tool_call_tokens")
op.drop_column("tool_call", "generated_images")
# Change tool_call_response back to JSONB before renaming
op.execute(
sa.text(
"""
ALTER TABLE tool_call
ALTER COLUMN tool_call_response TYPE JSONB
USING tool_call_response::jsonb
"""
)
)
op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
op.alter_column(
"tool_call", "tool_call_arguments", new_column_name="tool_arguments"
)
op.drop_column("tool_call", "reasoning_tokens")
op.drop_column("tool_call", "tool_call_id")
op.drop_column("tool_call", "turn_number")
op.drop_constraint(
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
)
op.drop_column("tool_call", "parent_tool_call_id")
op.alter_column(
"tool_call",
"parent_chat_message_id",
new_column_name="message_id",
nullable=False,
)
op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "chat_session_id")
op.add_column(
"chat_message",
sa.Column(
"research_answer_purpose",
sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
nullable=True,
),
)
op.add_column(
"chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
)
op.add_column(
"chat_message",
sa.Column(
"research_type",
sa.Enum("SIMPLE", "DEEP", name="researchtype"),
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)
op.drop_column("chat_message", "reasoning_tokens")
op.drop_constraint(
"fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message",
"latest_child_message_id",
new_column_name="latest_child_message",
)
op.drop_constraint(
"fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
)
op.alter_column(
"chat_message", "parent_message_id", new_column_name="parent_message"
)
# Recreate agent sub question and sub query tables
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_question", sa.Text(), nullable=False),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_answer", sa.Text(), nullable=False),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("parent_question_id", sa.Integer(), nullable=False),
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("sub_query", sa.Text(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"agent__sub_query__search_doc",
sa.Column("sub_query_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
)
# Recreate research agent tables
op.create_table(
"research_agent_iteration",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("purpose", sa.String(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"primary_question_id",
"iteration_nr",
name="_research_agent_iteration_unique_constraint",
),
)
op.create_table(
"research_agent_iteration_sub_step",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("primary_question_id", sa.Integer(), nullable=False),
sa.Column("iteration_nr", sa.Integer(), nullable=False),
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("sub_step_instructions", sa.String(), nullable=True),
sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
sa.Column("reasoning", sa.String(), nullable=True),
sa.Column("sub_answer", sa.String(), nullable=True),
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
sa.Column("claims", postgresql.JSONB(), nullable=True),
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
sa.Column("queries", postgresql.JSONB(), nullable=True),
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["primary_question_id", "iteration_nr"],
[
"research_agent_iteration.primary_question_id",
"research_agent_iteration.iteration_nr",
],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -1,4 +1,16 @@
variable "REPOSITORY" {
group "default" {
targets = ["backend", "model-server"]
}
variable "BACKEND_REPOSITORY" {
default = "onyxdotapp/onyx-backend"
}
variable "MODEL_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-model-server"
}
variable "INTEGRATION_REPOSITORY" {
default = "onyxdotapp/onyx-integration"
}
@@ -9,6 +21,22 @@ variable "TAG" {
target "backend" {
context = "."
dockerfile = "Dockerfile"
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}
target "model-server" {
context = "."
dockerfile = "Dockerfile.model_server"
cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
}
target "integration" {
@@ -20,8 +48,5 @@ target "integration" {
base = "target:backend"
}
cache-from = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache"]
cache-to = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache,mode=max"]
tags = ["${REPOSITORY}:${TAG}"]
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
}

View File

@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatSession.persona_id == persona_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatSession.persona_id == assistant_id,
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,

View File

@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
if chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
parent_message, _ = create_chat_history_chain(
chat_session_id=chat_session_id, db_session=db_session
)
except Exception:

View File

@@ -10,14 +10,14 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
class StandardAnswerRequest(BaseModel):
@@ -29,14 +29,12 @@ class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(ChunkContext):
message: str
search_type: SearchType
retrieval_options: RetrievalDetails
recency_bias_multiplier: float = 1.0
evaluation_type: LLMEvaluationType
# None to use system defaults for reranking
rerank_settings: RerankingDetails | None = None
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
class BasicCreateChatMessageRequest(ChunkContext):
@@ -96,17 +94,17 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
class AgentSubQuestion(BaseModel):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
class AgentAnswer(BaseModel):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
class AgentSubQuery(BaseModel):
sub_query: str
query_id: int

View File

@@ -24,23 +24,20 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import search_pipeline
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
@@ -58,33 +55,22 @@ class DocumentSearchPagination(BaseModel):
class DocumentSearchResponse(BaseModel):
top_documents: list[SavedSearchDocWithContent]
llm_indices: list[int]
pagination: DocumentSearchPagination
top_chunks: list[InferenceChunk]
def _normalize_pagination(limit: int | None, offset: int | None) -> tuple[int, int]:
if limit is None:
resolved_limit = NUM_RETURNED_HITS
else:
resolved_limit = limit
if resolved_limit <= 0:
raise HTTPException(
status_code=400, detail="retrieval_options.limit must be positive"
)
if offset is None:
resolved_offset = 0
else:
resolved_offset = offset
if resolved_offset < 0:
raise HTTPException(
status_code=400, detail="retrieval_options.offset cannot be negative"
)
return resolved_limit, resolved_offset
def _translate_search_request(
search_request: DocumentSearchRequest,
) -> ChunkSearchRequest:
return ChunkSearchRequest(
query=search_request.query,
hybrid_alpha=search_request.hybrid_alpha,
recency_bias_multiplier=search_request.recency_bias_multiplier,
query_keywords=search_request.query_keywords,
limit=search_request.limit,
offset=search_request.offset,
user_selected_filters=search_request.user_selected_filters,
# No bypass_acl, not allowed for this endpoint
)
@basic_router.post("/document-search")
@@ -94,103 +80,28 @@ def handle_search_request(
db_session: Session = Depends(get_session),
) -> DocumentSearchResponse:
"""Simple search endpoint, does not create a new message or records in the DB"""
query = search_request.message
query = search_request.query
logger.notice(f"Received document search query: {query}")
llm, fast_llm = get_default_llms()
pagination_limit, pagination_offset = _normalize_pagination(
limit=search_request.retrieval_options.limit,
offset=search_request.retrieval_options.offset,
llm, _ = get_default_llms()
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
search_settings=search_settings,
secondary_search_settings=None,
)
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
search_type=search_request.search_type,
human_selected_filters=search_request.retrieval_options.filters,
enable_auto_detect_filters=search_request.retrieval_options.enable_auto_detect_filters,
persona=None, # For simplicity, default settings should be good for this search
offset=pagination_offset,
limit=pagination_limit + 1,
rerank_settings=search_request.rerank_settings,
evaluation_type=search_request.evaluation_type,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,
full_doc=search_request.full_doc,
),
retrieved_chunks = search_pipeline(
chunk_search_request=_translate_search_request(search_request),
document_index=document_index,
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
persona=None,
db_session=db_session,
bypass_acl=False,
)
top_sections = search_pipeline.reranked_sections
relevance_sections = search_pipeline.section_relevance
top_docs = [
SavedSearchDocWithContent(
document_id=section.center_chunk.document_id,
chunk_ind=section.center_chunk.chunk_id,
content=section.center_chunk.content,
semantic_identifier=section.center_chunk.semantic_identifier or "Unknown",
link=(
section.center_chunk.source_links.get(0)
if section.center_chunk.source_links
else None
),
blurb=section.center_chunk.blurb,
source_type=section.center_chunk.source_type,
boost=section.center_chunk.boost,
hidden=section.center_chunk.hidden,
metadata=section.center_chunk.metadata,
score=section.center_chunk.score or 0.0,
match_highlights=section.center_chunk.match_highlights,
updated_at=section.center_chunk.updated_at,
primary_owners=section.center_chunk.primary_owners,
secondary_owners=section.center_chunk.secondary_owners,
is_internet=False,
db_doc_id=0,
)
for section in top_sections
]
# Track whether the underlying retrieval produced more items than requested
has_more_results = len(top_docs) > pagination_limit
# Deduping happens at the last step to avoid harming quality by dropping content early on
deduped_docs = top_docs
dropped_inds = None
if search_request.retrieval_options.dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections, items=deduped_docs
auto_detect_filters=False,
llm=llm,
)
if dropped_inds:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=deduped_docs,
dropped_indices=dropped_inds,
)
paginated_docs = deduped_docs[:pagination_limit]
llm_indices = [index for index in llm_indices if index < len(paginated_docs)]
has_more = has_more_results
pagination = DocumentSearchPagination(
offset=pagination_offset,
limit=pagination_limit,
returned_count=len(paginated_docs),
has_more=has_more,
next_offset=(pagination_offset + pagination_limit) if has_more else None,
)
return DocumentSearchResponse(
top_documents=paginated_docs,
llm_indices=llm_indices,
pagination=pagination,
)
return DocumentSearchResponse(top_chunks=retrieved_chunks)
def get_answer_stream(
@@ -275,10 +186,7 @@ def get_answer_with_citation(
answer=answer.answer,
chat_message_id=answer.message_id,
error_msg=answer.error_msg,
citations=[
CitationInfo(citation_num=i, document_id=doc_id)
for i, doc_id in answer.cited_documents.items()
],
citations=answer.citation_info,
docs=QADocsResponse(
top_documents=answer.top_documents,
predicted_flow=None,

View File

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
try:
# Older chats may not have the right structure
last_message, messages = create_chat_chain(
messages = create_chat_history_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
except RuntimeError:
return None

View File

@@ -116,7 +116,7 @@ def _concurrent_embedding(
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.error(f"Error encoding texts, retrying: {e}")
logger.warning(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)

View File

@@ -0,0 +1,73 @@
import json
from collections.abc import Sequence
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import FunctionMessage
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import FunctionCall
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithText
HUMAN = "human"
SYSTEM = "system"
AI = "ai"
FUNCTION = "function"
def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
def _base_message_to_chat_completion_msg(
msg: BaseMessage,
) -> ChatCompletionMessage:
if msg.type == HUMAN:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
user_msg: UserMessageWithText = {"role": "user", "content": content}
return user_msg
if msg.type == SYSTEM:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
system_msg: SystemMessage = {"role": "system", "content": content}
return system_msg
if msg.type == AI:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
if isinstance(msg, AIMessage) and msg.tool_calls:
assistant_msg["tool_calls"] = [
ToolCall(
id=tool_call.get("id") or "",
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=json.dumps(tool_call["args"]),
),
)
for tool_call in msg.tool_calls
]
return assistant_msg
if msg.type == FUNCTION:
function_message = cast(FunctionMessage, msg)
content = (
function_message.content
if isinstance(function_message.content, str)
else str(function_message.content)
)
tool_msg: ToolMessage = {
"role": "tool",
"content": content,
"tool_call_id": function_message.name or "",
}
return tool_msg
raise ValueError(f"Unexpected message type: {msg.type}")

View File

@@ -1,9 +1,13 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from agents.run_context import RunContextWrapper
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
@@ -14,8 +18,11 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
@dataclass
@@ -33,6 +40,75 @@ def _serialize_tool_output(output: Any) -> str:
return str(output)
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
tool_calls: list[dict[str, Any]] = []
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
if not isinstance(name, str) or arguments is None:
continue
if not isinstance(arguments, dict):
continue
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
if not tool_calls_from_content:
return
content_parts.clear()
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
@@ -65,149 +141,225 @@ def query(
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
new_messages_stateful: list[ChatCompletionMessage] = []
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
def stream_generator() -> Iterator[StreamEvent]:
reasoning_started = False
message_started = False
reasoning_started = False
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
content_parts: list[str] = []
reasoning_parts: list[str] = []
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
):
assert isinstance(chunk, ModelResponseStream)
synthetic_tool_call_counter = 0
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
if delta.reasoning_content:
reasoning_parts.append(delta.reasoning_content)
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
yield chunk
if not finish_reason:
continue
if delta.content:
content_parts.append(delta.content)
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started and not message_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
)
yield chunk
if not finish_reason:
continue
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
if finish_reason == "tool_calls" and tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
assistant_tool_calls.append(
if content_parts:
new_messages_stateful.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
"role": "assistant",
"content": "".join(content_parts),
}
)
span_generation.span_data.output = new_messages_stateful
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
assistant_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
# TODO this is still agentssdk, this should be removed
run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments) # type: ignore[attr-defined]
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
run_context = RunContextWrapper(context=context)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
elif finish_reason == "stop" and content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),

View File

@@ -26,9 +26,9 @@ def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> Non
# Without this patch, the library uses special formatting that breaks our custom tools
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
if tool_choice == "web_search":
return {"type": "function", "name": "web_search"}
return "web_search"
if tool_choice == "image_generation":
return {"type": "function", "name": "image_generation"}
return "image_generation"
return orig_func(cls, tool_choice)
OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]

View File

@@ -1,21 +1,21 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from pydantic import BaseModel
# from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class CoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
current_step_nr: int = 1
# log_messages: Annotated[list[str], add] = []
# current_step_nr: int = 1
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
# class SubgraphCoreState(BaseModel):
# """
# This is the core state that is shared across all subgraphs.
# """
log_messages: Annotated[list[str], add] = []
# log_messages: Annotated[list[str], add] = []

View File

@@ -1,62 +1,62 @@
from collections.abc import Hashable
from typing import cast
# from collections.abc import Hashable
# from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
# from langchain_core.runnables.config import RunnableConfig
# from langgraph.types import Send
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
# def parallel_object_source_research_edge(
# state: SearchSourcesObjectsUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
search_objects = state.analysis_objects
search_sources = state.analysis_sources
# search_objects = state.analysis_objects
# search_sources = state.analysis_sources
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
# object_source_combinations = [
# (object, source) for object in search_objects for source in search_sources
# ]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
# return [
# Send(
# "research_object_source",
# ObjectSourceInput(
# object_source_combination=object_source_combination,
# log_messages=[],
# ),
# )
# for object_source_combination in object_source_combinations
# ]
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
# def parallel_object_research_consolidation_edge(
# state: ObjectResearchInformationUpdate, config: RunnableConfig
# ) -> list[Send | Hashable]:
# """
# LangGraph edge to parallelize the research for an individual object and source
# """
# cast(GraphConfig, config["metadata"]["config"])
# object_research_information_results = state.object_research_information_results
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]
# return [
# Send(
# "consolidate_object_research",
# ObjectInformationInput(
# object_information=object_information,
# log_messages=[],
# ),
# )
# for object_information in object_research_information_results
# ]

View File

@@ -1,103 +1,103 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
# parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
# search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
# research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
# structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
# consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
# consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
test_mode = False
# test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
# """
# LangGraph graph builder for the knowledge graph search process.
# """
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# graph = StateGraph(
# state_schema=MainState,
# input=MainInput,
# )
### Add nodes ###
# ### Add nodes ###
graph.add_node(
"search_objects",
search_objects,
)
# graph.add_node(
# "search_objects",
# search_objects,
# )
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
# graph.add_node(
# "structure_research_by_source",
# structure_research_by_object,
# )
graph.add_node(
"research_object_source",
research_object_source,
)
# graph.add_node(
# "research_object_source",
# research_object_source,
# )
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
# graph.add_node(
# "consolidate_object_research",
# consolidate_object_research,
# )
graph.add_node(
"consolidate_research",
consolidate_research,
)
# graph.add_node(
# "consolidate_research",
# consolidate_research,
# )
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="search_objects")
# graph.add_edge(start_key=START, end_key="search_objects")
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
# graph.add_conditional_edges(
# source="search_objects",
# path=parallel_object_source_research_edge,
# path_map=["research_object_source"],
# )
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
# graph.add_edge(
# start_key="research_object_source",
# end_key="structure_research_by_source",
# )
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
# graph.add_conditional_edges(
# source="structure_research_by_source",
# path=parallel_object_research_consolidation_edge,
# path_map=["consolidate_object_research"],
# )
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
# graph.add_edge(
# start_key="consolidate_object_research",
# end_key="consolidate_research",
# )
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
# graph.add_edge(
# start_key="consolidate_research",
# end_key=END,
# )
return graph
# return graph

View File

@@ -1,146 +1,146 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def search_objects(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.prompt_builder.raw_user_query
search_tool = graph_config.tooling.search_tool
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# question = graph_config.inputs.prompt_builder.raw_user_query
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
# agent_1_instructions = extract_section(
# instructions, "Agent Step 1:", "Agent Step 2:"
# )
# if agent_1_instructions is None:
# raise ValueError("Agent 1 instructions not found")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
# agent_1_task = extract_section(
# agent_1_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_1_task is None:
# raise ValueError("Agent 1 task not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
# agent_1_independent_sources_str = extract_section(
# agent_1_instructions, "Independent Research Sources:", "Output Objective:"
# )
# if agent_1_independent_sources_str is None:
# raise ValueError("Agent 1 Independent Research Sources not found")
document_sources = strings_to_document_sources(
[
x.strip().lower()
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
)
# document_sources = strings_to_document_sources(
# [
# x.strip().lower()
# for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
# ]
# )
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
# agent_1_output_objective = extract_section(
# agent_1_instructions, "Output Objective:"
# )
# if agent_1_output_objective is None:
# raise ValueError("Agent 1 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# except Exception as e:
# raise ValueError(
# f"Agent 1 instructions not found or not formatted correctly: {e}"
# )
# Extract objects
# # Extract objects
if agent_1_base_data is None:
# Retrieve chunks for objects
# if agent_1_base_data is None:
# # Retrieve chunks for objects
retrieved_docs = research(question, search_tool)[:10]
# retrieved_docs = research(question, search_tool)[:10]
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
# dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# document_text=document_texts,
# objects_of_interest=agent_1_output_objective,
# )
# else:
# dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
# question=question,
# task=agent_1_task,
# base_data=agent_1_base_data,
# objects_of_interest=agent_1_output_objective,
# )
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_extraction_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
# cleaned_response = (
# str(llm_response.content)
# .replace("```json\n", "")
# .replace("\n```", "")
# .replace("\n", "")
# )
# cleaned_response = cleaned_response.split("OBJECTS:")[1]
# object_list = [x.strip() for x in cleaned_response.split(";")]
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
# except Exception as e:
# raise ValueError(f"Error in search_objects: {e}")
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)
# return SearchSourcesObjectsUpdate(
# analysis_objects=object_list,
# analysis_sources=document_sources,
# log_messages=["Agent 1 Task done"],
# )

View File

@@ -1,180 +1,180 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
# def research_object_source(
# state: ObjectSourceInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
object, document_source = state.object_source_combination
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
# object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.persona.system_prompt or ""
# try:
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
# agent_2_instructions = extract_section(
# instructions, "Agent Step 2:", "Agent Step 3:"
# )
# if agent_2_instructions is None:
# raise ValueError("Agent 2 instructions not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
# agent_2_task = extract_section(
# agent_2_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_2_task is None:
# raise ValueError("Agent 2 task not found")
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
# agent_2_time_cutoff = extract_section(
# agent_2_instructions, "Time Cutoff:", "Research Topics:"
# )
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
# agent_2_research_topics = extract_section(
# agent_2_instructions, "Research Topics:", "Output Objective"
# )
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
# agent_2_output_objective = extract_section(
# agent_2_instructions, "Output Objective:"
# )
# if agent_2_output_objective is None:
# raise ValueError("Agent 2 output objective not found")
except Exception:
raise ValueError(
"Agent 1 instructions not found or not formatted correctly: {e}"
)
# except Exception:
# raise ValueError(
# "Agent 1 instructions not found or not formatted correctly: {e}"
# )
# Populate prompt
# # Populate prompt
# Retrieve chunks for objects
# # Retrieve chunks for objects
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
# if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
# if agent_2_time_cutoff.strip().endswith("d"):
# try:
# days = int(agent_2_time_cutoff.strip()[:-1])
# agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
# days=days
# )
# except ValueError:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# raise ValueError(
# f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
# )
# else:
# agent_2_source_start_time = None
document_sources = [document_source] if document_source else None
# document_sources = [document_source] if document_source else None
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
# if len(question.strip()) > 0:
# research_area = f"{question} for {object}"
# elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
# research_area = f"{agent_2_research_topics} for {object}"
# else:
# research_area = object
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# retrieved_docs = research(
# question=research_area,
# search_tool=search_tool,
# document_sources=document_sources,
# time_cutoff=agent_2_source_start_time,
# )
# Generate document text
# # Generate document text
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
# document_texts_list = []
# for doc_num, doc in enumerate(retrieved_docs):
# chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
# document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# document_texts = "\n\n".join(document_texts_list)
# Built prompt
# # Built prompt
today = datetime.now().strftime("%A, %Y-%m-%d")
# today = datetime.now().strftime("%A, %Y-%m-%d")
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# dc_object_source_research_prompt = (
# DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
# today=today,
# question=question,
# task=agent_2_task,
# document_text=document_texts,
# format=agent_2_output_objective,
# )
# .replace("---object---", object)
# .replace("---source---", document_source.value)
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_source_research_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
# object_research_results = {
# "object": object,
# "source": document_source.value,
# "research_result": cleaned_response,
# }
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
# except Exception as e:
# raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
# logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)
# return ObjectSourceResearchUpdate(
# object_source_research_results=[object_research_results],
# log_messages=["Agent Step 2 done for one object"],
# )

View File

@@ -1,48 +1,48 @@
from collections import defaultdict
from typing import Dict
from typing import List
# from collections import defaultdict
# from typing import Dict
# from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
# ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def structure_research_by_object(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
# """
# LangGraph node to start the agentic search process.
# """
object_source_research_results = state.object_source_research_results
# object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
# object_research_information_results: List[Dict[str, str]] = []
# object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
# for object_source_research in object_source_research_results:
# object = object_source_research["object"]
# source = object_source_research["source"]
# research_result = object_source_research["research_result"]
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
# object_research_information_results_list[object].append(
# f"Source: {source}\n{research_result}"
# )
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
# for object, information in object_research_information_results_list.items():
# object_research_information_results.append(
# {"object": object, "information": "\n".join(information)}
# )
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
# logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)
# return ObjectResearchInformationUpdate(
# object_research_information_results=object_research_information_results,
# log_messages=["A3 - Object Research Information structured"],
# )

View File

@@ -1,103 +1,103 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.prompt_builder.raw_user_query
# def consolidate_object_research(
# state: ObjectInformationInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# search_tool = graph_config.tooling.search_tool
# question = graph_config.inputs.prompt_builder.raw_user_query
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.persona.system_prompt or ""
# instructions = graph_config.inputs.persona.system_prompt or ""
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
# agent_4_instructions = extract_section(
# instructions, "Agent Step 4:", "Agent Step 5:"
# )
# if agent_4_instructions is None:
# raise ValueError("Agent 4 instructions not found")
# agent_4_output_objective = extract_section(
# agent_4_instructions, "Output Objective:"
# )
# if agent_4_output_objective is None:
# raise ValueError("Agent 4 output objective not found")
object_information = state.object_information
# object_information = state.object_information
object = object_information["object"]
information = object_information["information"]
# object = object_information["object"]
# information = object_information["information"]
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
# question=question,
# object=object,
# information=information,
# format=agent_4_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke_langchain,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_object_consolidation_prompt,
# reserved_str="",
# ),
# )
# ]
# primary_llm = graph_config.tooling.primary_llm
# # Grader
# try:
# llm_response = run_with_timeout(
# 30,
# primary_llm.invoke_langchain,
# prompt=msg,
# timeout_override=30,
# max_tokens=300,
# )
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
# cleaned_response = str(llm_response.content).replace("```json\n", "")
# consolidated_information = cleaned_response.split("INFORMATION:")[1]
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_object_research: {e}")
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
# object_research_results = {
# "object": object,
# "research_result": consolidated_information,
# }
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
# logger.debug(
# "DivCon Step A4 - Object Research Consolidation - completed for an object"
# )
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)
# return ObjectResearchUpdate(
# object_research_results=[object_research_results],
# log_messages=["Agent Source Consilidation done"],
# )

View File

@@ -1,127 +1,127 @@
from typing import cast
# from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
# trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
# def consolidate_research(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
# """
# LangGraph node to start the agentic search process.
# """
graph_config = cast(GraphConfig, config["metadata"]["config"])
# graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
# search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# if search_tool is None or graph_config.inputs.persona is None:
# raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.persona.system_prompt or ""
# # Populate prompt
# instructions = graph_config.inputs.persona.system_prompt or ""
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
# try:
# agent_5_instructions = extract_section(
# instructions, "Agent Step 5:", "Agent End"
# )
# if agent_5_instructions is None:
# raise ValueError("Agent 5 instructions not found")
# agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
# agent_5_task = extract_section(
# agent_5_instructions, "Task:", "Independent Research Sources:"
# )
# if agent_5_task is None:
# raise ValueError("Agent 5 task not found")
# agent_5_output_objective = extract_section(
# agent_5_instructions, "Output Objective:"
# )
# if agent_5_output_objective is None:
# raise ValueError("Agent 5 output objective not found")
# except ValueError as e:
# raise ValueError(
# f"Instructions for Agent Step 5 were not properly formatted: {e}"
# )
research_result_list = []
# research_result_list = []
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
# if agent_5_task.strip() == "*concatenate*":
# object_research_results = state.object_research_results
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
# for object_research_result in object_research_results:
# object = object_research_result["object"]
# research_result = object_research_result["research_result"]
# research_result_list.append(f"Object: {object}\n\n{research_result}")
research_results = "\n\n".join(research_result_list)
# research_results = "\n\n".join(research_result_list)
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# else:
# raise NotImplementedError("Only '*concatenate*' is currently supported")
# Create a prompt for the object consolidation
# # Create a prompt for the object consolidation
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# if agent_5_base_data is None:
# dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
# text=research_results,
# format=agent_5_output_objective,
# )
# else:
# dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
# base_data=agent_5_base_data,
# text=research_results,
# format=agent_5_output_objective,
# )
# Run LLM
# # Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
# msg = [
# HumanMessage(
# content=trim_prompt_piece(
# config=graph_config.tooling.primary_llm.config,
# prompt_piece=dc_formatting_prompt,
# reserved_str="",
# ),
# )
# ]
try:
_ = run_with_timeout(
60,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)
# try:
# _ = run_with_timeout(
# 60,
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=msg,
# event_name="initial_agent_answer",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=30,
# max_tokens=None,
# ),
# )
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
logger.debug("DivCon Step A5 - Final Generation - completed")
# logger.debug("DivCon Step A5 - Final Generation - completed")
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)
# return ResearchUpdate(
# research_results=research_results,
# log_messages=["Agent Source Consilidation done"],
# )

View File

@@ -1,61 +1,50 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
# FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
# def research(
# question: str,
# search_tool: SearchTool,
# document_sources: list[DocumentSource] | None = None,
# time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
# # new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
# retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
# for tool_response in search_tool.run(
# query=question,
# override_kwargs=SearchToolOverrideKwargs(original_query=question),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
# retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
# break
# return retrieved_docs
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
# def extract_section(
# text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
# """Extract text between markers, returning None if markers not found"""
# parts = text.split(start_marker)
if len(parts) == 1:
return None
# if len(parts) == 1:
# return None
after_start = parts[1].strip()
# after_start = parts[1].strip()
if not end_marker:
return after_start
# if not end_marker:
# return after_start
extract = after_start.split(end_marker)[0]
# extract = after_start.split(end_marker)[0]
return extract.strip()
# return extract.strip()

View File

@@ -1,72 +1,72 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# ### States ###
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
# analysis_objects: list[str] = []
# analysis_sources: list[DocumentSource] = []
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceInput(LoggerUpdate):
# object_source_combination: tuple[str, DocumentSource]
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectSourceResearchUpdate(LoggerUpdate):
# object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
# class ObjectInformationInput(LoggerUpdate):
# object_information: Dict[str, str]
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchInformationUpdate(LoggerUpdate):
# object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
# object_research_results: Annotated[list[Dict[str, str]], add] = []
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
# class ResearchUpdate(LoggerUpdate):
# research_results: str | None = None
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# ToolChoiceInput,
# ToolCallUpdate,
# ToolChoiceUpdate,
# SearchSourcesObjectsUpdate,
# ObjectSourceResearchUpdate,
# ObjectResearchInformationUpdate,
# ObjectResearchUpdate,
# ResearchUpdate,
# ):
# pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
# log_messages: list[str]

View File

@@ -1,36 +1,36 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class RefinementSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
# class RefinementSubQuestion(BaseModel):
# sub_question: str
# sub_question_id: str
# verified: bool
# answered: bool
# answer: str
class AgentTimings(BaseModel):
base_duration_s: float | None
refined_duration_s: float | None
full_duration_s: float | None
# class AgentTimings(BaseModel):
# base_duration_s: float | None
# refined_duration_s: float | None
# full_duration_s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration_s: float | None = None
# class AgentBaseMetrics(BaseModel):
# num_verified_documents_total: int | None
# num_verified_documents_core: int | None
# verified_avg_score_core: float | None
# num_verified_documents_base: int | float | None
# verified_avg_score_base: float | None = None
# base_doc_boost_factor: float | None = None
# support_boost_factor: float | None = None
# duration_s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
# refined_doc_boost_factor: float | None = None
# refined_question_boost_factor: float | None = None
# duration_s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass
# class AgentAdditionalMetrics(BaseModel):
# pass

View File

@@ -1,61 +1,61 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.graph import END
from langgraph.types import Send
# from langgraph.graph import END
# from langgraph.types import Send
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
if not state.tools_used:
raise IndexError("state.tools_used cannot be empty")
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
# if not state.tools_used:
# raise IndexError("state.tools_used cannot be empty")
# next_tool is either a generic tool name or a DRPath string
next_tool_name = state.tools_used[-1]
# # next_tool is either a generic tool name or a DRPath string
# next_tool_name = state.tools_used[-1]
available_tools = state.available_tools
if not available_tools:
raise ValueError("No tool is available. This should not happen.")
# available_tools = state.available_tools
# if not available_tools:
# raise ValueError("No tool is available. This should not happen.")
if next_tool_name in available_tools:
next_tool_path = available_tools[next_tool_name].path
elif next_tool_name == DRPath.END.value:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
elif next_tool_name == DRPath.CLOSER.value:
return DRPath.CLOSER
else:
return DRPath.ORCHESTRATOR
# if next_tool_name in available_tools:
# next_tool_path = available_tools[next_tool_name].path
# elif next_tool_name == DRPath.END.value:
# return END
# elif next_tool_name == DRPath.LOGGER.value:
# return DRPath.LOGGER
# elif next_tool_name == DRPath.CLOSER.value:
# return DRPath.CLOSER
# else:
# return DRPath.ORCHESTRATOR
# handle invalid paths
if next_tool_path == DRPath.CLARIFIER:
raise ValueError("CLARIFIER is not a valid path during iteration")
# # handle invalid paths
# if next_tool_path == DRPath.CLARIFIER:
# raise ValueError("CLARIFIER is not a valid path during iteration")
# handle tool calls without a query
if (
next_tool_path
in (
DRPath.INTERNAL_SEARCH,
DRPath.WEB_SEARCH,
DRPath.KNOWLEDGE_GRAPH,
DRPath.IMAGE_GENERATION,
)
and len(state.query_list) == 0
):
return DRPath.CLOSER
# # handle tool calls without a query
# if (
# next_tool_path
# in (
# DRPath.INTERNAL_SEARCH,
# DRPath.WEB_SEARCH,
# DRPath.KNOWLEDGE_GRAPH,
# DRPath.IMAGE_GENERATION,
# )
# and len(state.query_list) == 0
# ):
# return DRPath.CLOSER
return next_tool_path
# return next_tool_path
def completeness_router(state: MainState) -> DRPath | str:
if not state.tools_used:
raise IndexError("tools_used cannot be empty")
# def completeness_router(state: MainState) -> DRPath | str:
# if not state.tools_used:
# raise IndexError("tools_used cannot be empty")
# go to closer if path is CLOSER or no queries
next_path = state.tools_used[-1]
# # go to closer if path is CLOSER or no queries
# next_path = state.tools_used[-1]
if next_path == DRPath.ORCHESTRATOR.value:
return DRPath.ORCHESTRATOR
return DRPath.LOGGER
# if next_path == DRPath.ORCHESTRATOR.value:
# return DRPath.ORCHESTRATOR
# return DRPath.LOGGER

View File

@@ -1,31 +1,27 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# from onyx.agents.agent_search.dr.enums import DRPath
MAX_CHAT_HISTORY_MESSAGES = (
3 # note: actual count is x2 to account for user and assistant messages
)
# MAX_CHAT_HISTORY_MESSAGES = (
# 3 # note: actual count is x2 to account for user and assistant messages
# )
MAX_DR_PARALLEL_SEARCH = 4
# MAX_DR_PARALLEL_SEARCH = 4
# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
0 # how many times the closer can send back to the orchestrator
)
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
# 0 # how many times the closer can send back to the orchestrator
# )
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
DRPath.INTERNAL_SEARCH: 1.0,
DRPath.KNOWLEDGE_GRAPH: 2.0,
DRPath.WEB_SEARCH: 1.5,
DRPath.IMAGE_GENERATION: 3.0,
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
DRPath.CLOSER: 0.0,
}
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
# DRPath.INTERNAL_SEARCH: 1.0,
# DRPath.KNOWLEDGE_GRAPH: 2.0,
# DRPath.WEB_SEARCH: 1.5,
# DRPath.IMAGE_GENERATION: 3.0,
# DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
# DRPath.CLOSER: 0.0,
# }
DR_TIME_BUDGET_BY_TYPE = {
ResearchType.THOUGHTFUL: 3.0,
ResearchType.DEEP: 12.0,
ResearchType.FAST: 0.5,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0

View File

@@ -1,112 +1,111 @@
from datetime import datetime
# from datetime import datetime
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate
def get_dr_prompt_orchestration_templates(
purpose: DRPromptPurpose,
research_type: ResearchType,
available_tools: dict[str, OrchestratorTool],
entity_types_string: str | None = None,
relationship_types_string: str | None = None,
reasoning_result: str | None = None,
tool_calls_string: str | None = None,
) -> PromptTemplate:
available_tools = available_tools or {}
tool_names = list(available_tools.keys())
tool_description_str = "\n\n".join(
f"- {tool_name}: {tool.description}"
for tool_name, tool in available_tools.items()
)
tool_cost_str = "\n".join(
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
)
# def get_dr_prompt_orchestration_templates(
# purpose: DRPromptPurpose,
# use_agentic_search: bool,
# available_tools: dict[str, OrchestratorTool],
# entity_types_string: str | None = None,
# relationship_types_string: str | None = None,
# reasoning_result: str | None = None,
# tool_calls_string: str | None = None,
# ) -> PromptTemplate:
# available_tools = available_tools or {}
# tool_names = list(available_tools.keys())
# tool_description_str = "\n\n".join(
# f"- {tool_name}: {tool.description}"
# for tool_name, tool in available_tools.items()
# )
# tool_cost_str = "\n".join(
# f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
# )
tool_differentiations: list[str] = [
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
for tool_1 in available_tools
for tool_2 in available_tools
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
]
tool_differentiation_hint_string = (
"\n".join(tool_differentiations) or "(No differentiating hints available)"
)
# TODO: add tool deliniation pairs for custom tools as well
# tool_differentiations: list[str] = [
# TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
# for tool_1 in available_tools
# for tool_2 in available_tools
# if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
# ]
# tool_differentiation_hint_string = (
# "\n".join(tool_differentiations) or "(No differentiating hints available)"
# )
# # TODO: add tool deliniation pairs for custom tools as well
tool_question_hint_string = (
"\n".join(
"- " + TOOL_QUESTION_HINTS[tool]
for tool in available_tools
if tool in TOOL_QUESTION_HINTS
)
or "(No examples available)"
)
# tool_question_hint_string = (
# "\n".join(
# "- " + TOOL_QUESTION_HINTS[tool]
# for tool in available_tools
# if tool in TOOL_QUESTION_HINTS
# )
# or "(No examples available)"
# )
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
entity_types_string or relationship_types_string
):
# if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
# entity_types_string or relationship_types_string
# ):
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
possible_entities=entity_types_string or "",
possible_relationships=relationship_types_string or "",
)
else:
kg_types_descriptions = "(The Knowledge Graph is not used.)"
# kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
# possible_entities=entity_types_string or "",
# possible_relationships=relationship_types_string or "",
# )
# else:
# kg_types_descriptions = "(The Knowledge Graph is not used.)"
if purpose == DRPromptPurpose.PLAN:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("plan generation is not supported for FAST time budget")
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# if purpose == DRPromptPurpose.PLAN:
# if not use_agentic_search:
# raise ValueError("plan generation is only supported for agentic search")
# base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
else:
raise ValueError(
"reasoning is not separately required for DEEP time budget"
)
# elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# else:
# raise ValueError(
# "reasoning is not separately required for agentic search"
# )
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
# base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
elif purpose == DRPromptPurpose.NEXT_STEP:
if research_type == ResearchType.THOUGHTFUL:
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
else:
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# elif purpose == DRPromptPurpose.NEXT_STEP:
# if not use_agentic_search:
# base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# else:
# base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
elif purpose == DRPromptPurpose.CLARIFICATION:
if research_type == ResearchType.THOUGHTFUL:
raise ValueError("clarification is not supported for FAST time budget")
base_template = GET_CLARIFICATION_PROMPT
# elif purpose == DRPromptPurpose.CLARIFICATION:
# if not use_agentic_search:
# raise ValueError("clarification is only supported for agentic search")
# base_template = GET_CLARIFICATION_PROMPT
else:
# for mypy, clearly a mypy bug
raise ValueError(f"Invalid purpose: {purpose}")
# else:
# # for mypy, clearly a mypy bug
# raise ValueError(f"Invalid purpose: {purpose}")
return base_template.partial_build(
num_available_tools=str(len(tool_names)),
available_tools=", ".join(tool_names),
tool_choice_options=" or ".join(tool_names),
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
kg_types_descriptions=kg_types_descriptions,
tool_descriptions=tool_description_str,
tool_differentiation_hints=tool_differentiation_hint_string,
tool_question_hints=tool_question_hint_string,
average_tool_costs=tool_cost_str,
reasoning_result=reasoning_result or "(No reasoning result provided.)",
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
)
# return base_template.partial_build(
# num_available_tools=str(len(tool_names)),
# available_tools=", ".join(tool_names),
# tool_choice_options=" or ".join(tool_names),
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# kg_types_descriptions=kg_types_descriptions,
# tool_descriptions=tool_description_str,
# tool_differentiation_hints=tool_differentiation_hint_string,
# tool_question_hints=tool_question_hint_string,
# average_tool_costs=tool_cost_str,
# reasoning_result=reasoning_result or "(No reasoning result provided.)",
# tool_calls_string=tool_calls_string or "(No tool calls provided.)",
# )

View File

@@ -1,32 +1,22 @@
from enum import Enum
# from enum import Enum
class ResearchType(str, Enum):
"""Research type options for agent search operations"""
# class ResearchAnswerPurpose(str, Enum):
# """Research answer purpose options for agent search operations"""
# BASIC = "BASIC"
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
THOUGHTFUL = "THOUGHTFUL"
DEEP = "DEEP"
FAST = "FAST"
# ANSWER = "ANSWER"
# CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class ResearchAnswerPurpose(str, Enum):
"""Research answer purpose options for agent search operations"""
ANSWER = "ANSWER"
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
class DRPath(str, Enum):
CLARIFIER = "Clarifier"
ORCHESTRATOR = "Orchestrator"
INTERNAL_SEARCH = "Internal Search"
GENERIC_TOOL = "Generic Tool"
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
WEB_SEARCH = "Web Search"
IMAGE_GENERATION = "Image Generation"
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
CLOSER = "Closer"
LOGGER = "Logger"
END = "End"
# class DRPath(str, Enum):
# CLARIFIER = "Clarifier"
# ORCHESTRATOR = "Orchestrator"
# INTERNAL_SEARCH = "Internal Search"
# GENERIC_TOOL = "Generic Tool"
# KNOWLEDGE_GRAPH = "Knowledge Graph Search"
# WEB_SEARCH = "Web Search"
# IMAGE_GENERATION = "Image Generation"
# GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
# CLOSER = "Closer"
# LOGGER = "Logger"
# END = "End"

View File

@@ -1,88 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
dr_ws_graph_builder,
)
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
# dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
# dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
# dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
# dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
# dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
# dr_ws_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
def dr_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the deep research agent.
"""
# def dr_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for the deep research agent.
# """
graph = StateGraph(state_schema=MainState, input=MainInput)
# graph = StateGraph(state_schema=MainState, input=MainInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node(DRPath.CLARIFIER, clarifier)
# graph.add_node(DRPath.CLARIFIER, clarifier)
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
# graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
basic_search_graph = dr_basic_search_graph_builder().compile()
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
# basic_search_graph = dr_basic_search_graph_builder().compile()
# graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
kg_search_graph = dr_kg_search_graph_builder().compile()
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
# kg_search_graph = dr_kg_search_graph_builder().compile()
# graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
internet_search_graph = dr_ws_graph_builder().compile()
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
# internet_search_graph = dr_ws_graph_builder().compile()
# graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
image_generation_graph = dr_image_generation_graph_builder().compile()
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
# image_generation_graph = dr_image_generation_graph_builder().compile()
# graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
custom_tool_graph = dr_custom_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
# custom_tool_graph = dr_custom_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
# generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
# graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
graph.add_node(DRPath.CLOSER, closer)
graph.add_node(DRPath.LOGGER, logging)
# graph.add_node(DRPath.CLOSER, closer)
# graph.add_node(DRPath.LOGGER, logging)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
# graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
# graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
# graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
# graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
# graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
# graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
return graph
# return graph

View File

@@ -1,131 +1,131 @@
from enum import Enum
# from enum import Enum
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool
class OrchestratorStep(BaseModel):
tool: str
questions: list[str]
# class OrchestratorStep(BaseModel):
# tool: str
# questions: list[str]
class OrchestratorDecisonsNoPlan(BaseModel):
reasoning: str
next_step: OrchestratorStep
# class OrchestratorDecisonsNoPlan(BaseModel):
# reasoning: str
# next_step: OrchestratorStep
class OrchestrationPlan(BaseModel):
reasoning: str
plan: str
# class OrchestrationPlan(BaseModel):
# reasoning: str
# plan: str
class ClarificationGenerationResponse(BaseModel):
clarification_needed: bool
clarification_question: str
# class ClarificationGenerationResponse(BaseModel):
# clarification_needed: bool
# clarification_question: str
class DecisionResponse(BaseModel):
reasoning: str
decision: str
# class DecisionResponse(BaseModel):
# reasoning: str
# decision: str
class QueryEvaluationResponse(BaseModel):
reasoning: str
query_permitted: bool
# class QueryEvaluationResponse(BaseModel):
# reasoning: str
# query_permitted: bool
class OrchestrationClarificationInfo(BaseModel):
clarification_question: str
clarification_response: str | None = None
# class OrchestrationClarificationInfo(BaseModel):
# clarification_question: str
# clarification_response: str | None = None
class WebSearchAnswer(BaseModel):
urls_to_open_indices: list[int]
# class WebSearchAnswer(BaseModel):
# urls_to_open_indices: list[int]
class SearchAnswer(BaseModel):
reasoning: str
answer: str
claims: list[str] | None = None
# class SearchAnswer(BaseModel):
# reasoning: str
# answer: str
# claims: list[str] | None = None
class TestInfoCompleteResponse(BaseModel):
reasoning: str
complete: bool
gaps: list[str]
# class TestInfoCompleteResponse(BaseModel):
# reasoning: str
# complete: bool
# gaps: list[str]
# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
tool_id: int
name: str
llm_path: str # the path for the LLM to refer by
path: DRPath # the actual path in the graph
description: str
metadata: dict[str, str]
cost: float
tool_object: Tool | None = None # None for CLOSER
# # TODO: revisit with custom tools implementation in v2
# # each tool should be a class with the attributes below, plus the actual tool implementation
# # this will also allow custom tools to have their own cost
# class OrchestratorTool(BaseModel):
# tool_id: int
# name: str
# llm_path: str # the path for the LLM to refer by
# path: DRPath # the actual path in the graph
# description: str
# metadata: dict[str, str]
# cost: float
# tool_object: Tool | None = None # None for CLOSER
class Config:
arbitrary_types_allowed = True
# class Config:
# arbitrary_types_allowed = True
class IterationInstructions(BaseModel):
iteration_nr: int
plan: str | None
reasoning: str
purpose: str
# class IterationInstructions(BaseModel):
# iteration_nr: int
# plan: str | None
# reasoning: str
# purpose: str
class IterationAnswer(BaseModel):
tool: str
tool_id: int
iteration_nr: int
parallelization_nr: int
question: str
reasoning: str | None
answer: str
cited_documents: dict[int, InferenceSection]
background_info: str | None = None
claims: list[str] | None = None
additional_data: dict[str, str] | None = None
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we'll can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible to tools
queries: list[str] | None = None
# class IterationAnswer(BaseModel):
# tool: str
# tool_id: int
# iteration_nr: int
# parallelization_nr: int
# question: str
# reasoning: str | None
# answer: str
# cited_documents: dict[int, InferenceSection]
# background_info: str | None = None
# claims: list[str] | None = None
# additional_data: dict[str, str] | None = None
# response_type: str | None = None
# data: dict | list | str | int | float | bool | None = None
# file_ids: list[str] | None = None
# # TODO: This is not ideal, but we'll can rework the schema
# # for deep research later
# is_web_fetch: bool = False
# # for image generation step-types
# generated_images: list[GeneratedImage] | None = None
# # for multi-query search tools (v2 web search and internal search)
# # TODO: Clean this up to be more flexible to tools
# queries: list[str] | None = None
class AggregatedDRContext(BaseModel):
context: str
cited_documents: list[InferenceSection]
is_internet_marker_dict: dict[str, bool]
global_iteration_responses: list[IterationAnswer]
# class AggregatedDRContext(BaseModel):
# context: str
# cited_documents: list[InferenceSection]
# is_internet_marker_dict: dict[str, bool]
# global_iteration_responses: list[IterationAnswer]
class DRPromptPurpose(str, Enum):
PLAN = "PLAN"
NEXT_STEP = "NEXT_STEP"
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
CLARIFICATION = "CLARIFICATION"
# class DRPromptPurpose(str, Enum):
# PLAN = "PLAN"
# NEXT_STEP = "NEXT_STEP"
# NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
# NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
# CLARIFICATION = "CLARIFICATION"
class BaseSearchProcessingResponse(BaseModel):
specified_source_types: list[str]
rewritten_query: str
time_filter: str
# class BaseSearchProcessingResponse(BaseModel):
# specified_source_types: list[str]
# rewritten_query: str
# time_filter: str

File diff suppressed because it is too large Load Diff

View File

@@ -1,423 +1,418 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
# logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
def insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
# def insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
Args:
message_id: The ID of the chat message
search_doc_id: The ID of the search document
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_id: The ID of the search document
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = extract_citation_numbers(final_answer)
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = extract_citation_numbers(final_answer)
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def closer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# def closer(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
research_type = graph_config.behavior.research_type
# use_agentic_search = graph_config.behavior.use_agentic_search
assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_task_prompt = state.assistant_task_prompt
# assistant_system_prompt: str = state.assistant_system_prompt or ""
# assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
# uploaded_context = state.uploaded_test_context or ""
clarification = state.clarification
prompt_question = get_prompt_question(base_question, clarification)
# clarification = state.clarification
# prompt_question = get_prompt_question(base_question, clarification)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
)
# chat_history_string = (
# get_chat_history_string(
# graph_config.inputs.prompt_builder.message_history,
# MAX_CHAT_HISTORY_MESSAGES,
# )
# or "(No chat history yet available)"
# )
aggregated_context_w_docs = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context_w_docs = aggregate_context(
# state.iteration_responses, include_documents=True
# )
aggregated_context_wo_docs = aggregate_context(
state.iteration_responses, include_documents=False
)
# aggregated_context_wo_docs = aggregate_context(
# state.iteration_responses, include_documents=False
# )
iteration_responses_w_docs_string = aggregated_context_w_docs.context
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
all_cited_documents = aggregated_context_w_docs.cited_documents
# iteration_responses_w_docs_string = aggregated_context_w_docs.context
# iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
# all_cited_documents = aggregated_context_w_docs.cited_documents
num_closer_suggestions = state.num_closer_suggestions
# num_closer_suggestions = state.num_closer_suggestions
if (
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
and research_type == ResearchType.DEEP
):
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
base_question=prompt_question,
questions_answers_claims=iteration_responses_wo_docs_string,
chat_history_string=chat_history_string,
high_level_plan=(
state.plan_of_record.plan
if state.plan_of_record
else "No plan available"
),
)
# if (
# num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
# and use_agentic_search
# ):
# test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
# base_question=prompt_question,
# questions_answers_claims=iteration_responses_wo_docs_string,
# chat_history_string=chat_history_string,
# high_level_plan=(
# state.plan_of_record.plan
# if state.plan_of_record
# else "No plan available"
# ),
# )
test_info_complete_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
test_info_complete_prompt + (assistant_task_prompt or ""),
),
schema=TestInfoCompleteResponse,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1000,
)
# test_info_complete_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# test_info_complete_prompt + (assistant_task_prompt or ""),
# ),
# schema=TestInfoCompleteResponse,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1000,
# )
if test_info_complete_json.complete:
pass
# if test_info_complete_json.complete:
# pass
else:
return OrchestrationUpdate(
tools_used=[DRPath.ORCHESTRATOR.value],
query_list=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
gaps=test_info_complete_json.gaps,
num_closer_suggestions=num_closer_suggestions + 1,
)
# else:
# return OrchestrationUpdate(
# tools_used=[DRPath.ORCHESTRATOR.value],
# query_list=[],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# gaps=test_info_complete_json.gaps,
# num_closer_suggestions=num_closer_suggestions + 1,
# )
retrieved_search_docs = convert_inference_sections_to_search_docs(
all_cited_documents
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# all_cited_documents
# )
write_custom_event(
current_step_nr,
MessageStart(
content="",
final_documents=retrieved_search_docs,
),
writer,
)
# write_custom_event(
# current_step_nr,
# MessageStart(
# content="",
# final_documents=retrieved_search_docs,
# ),
# writer,
# )
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
elif research_type == ResearchType.DEEP:
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
else:
raise ValueError(f"Invalid research type: {research_type}")
# if not use_agentic_search:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# else:
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
estimated_final_answer_prompt_tokens = check_number_of_tokens(
final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_w_docs_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
)
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
# final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_w_docs_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
# )
# for DR, rely only on sub-answers and claims to save tokens if context is too long
# TODO: consider compression step for Thoughtful mode if context is too long.
# Should generally not be the case though.
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
# # TODO: consider compression step for Thoughtful mode if context is too long.
# # Should generally not be the case though.
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
if (
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
and research_type == ResearchType.DEEP
):
iteration_responses_string = iteration_responses_wo_docs_string
else:
iteration_responses_string = iteration_responses_w_docs_string
# if (
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
# and use_agentic_search
# ):
# iteration_responses_string = iteration_responses_wo_docs_string
# else:
# iteration_responses_string = iteration_responses_w_docs_string
final_answer_prompt = final_answer_base_prompt.build(
base_question=prompt_question,
iteration_responses_string=iteration_responses_string,
chat_history_string=chat_history_string,
uploaded_context=uploaded_context,
)
# final_answer_prompt = final_answer_base_prompt.build(
# base_question=prompt_question,
# iteration_responses_string=iteration_responses_string,
# chat_history_string=chat_history_string,
# uploaded_context=uploaded_context,
# )
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
# if graph_config.inputs.project_instructions:
# assistant_system_prompt = (
# assistant_system_prompt
# + PROJECT_INSTRUCTIONS_SEPARATOR
# + (graph_config.inputs.project_instructions or "")
# )
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents
]
# all_context_llmdocs = [
# llm_doc_from_inference_section(inference_section)
# for inference_section in all_cited_documents
# ]
try:
streamed_output, _, citation_infos = run_with_timeout(
int(3 * TF_DR_TIMEOUT_LONG),
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt,
final_answer_prompt + (assistant_task_prompt or ""),
),
event_name="basic_response",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
answer_piece=StreamingType.MESSAGE_DELTA.value,
ind=current_step_nr,
context_docs=all_context_llmdocs,
replace_citations=True,
# max_tokens=None,
),
)
# try:
# streamed_output, _, citation_infos = run_with_timeout(
# int(3 * TF_DR_TIMEOUT_LONG),
# lambda: stream_llm_answer(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt,
# final_answer_prompt + (assistant_task_prompt or ""),
# ),
# event_name="basic_response",
# writer=writer,
# agent_answer_level=0,
# agent_answer_question_num=0,
# agent_answer_type="agent_level_answer",
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
# answer_piece=StreamingType.MESSAGE_DELTA.value,
# ind=current_step_nr,
# context_docs=all_context_llmdocs,
# replace_citations=True,
# # max_tokens=None,
# ),
# )
final_answer = "".join(streamed_output)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
# final_answer = "".join(streamed_output)
# except Exception as e:
# raise ValueError(f"Error in consolidate_research: {e}")
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
write_custom_event(current_step_nr, CitationStart(), writer)
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
write_custom_event(current_step_nr, SectionEnd(), writer)
# write_custom_event(current_step_nr, CitationStart(), writer)
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
# write_custom_event(current_step_nr, SectionEnd(), writer)
current_step_nr += 1
# current_step_nr += 1
# Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# )
# # Log the research agent steps
# # save_iteration(
# # state,
# # graph_config,
# # aggregated_context,
# # final_answer,
# # all_cited_documents,
# # is_internet_marker_dict,
# # )
return FinalUpdate(
final_answer=final_answer,
all_cited_documents=all_cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="closer",
node_start_time=node_start_time,
)
],
)
# return FinalUpdate(
# final_answer=final_answer,
# all_cited_documents=all_cited_documents,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="closer",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,248 +1,246 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
# GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
# def _extract_citation_numbers(text: str) -> list[int]:
# """
# Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
# Returns a list of all unique citation numbers found.
# """
# # Pattern to match [[number]] or [[number1, number2, ...]]
# pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
# matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
# cited_numbers = []
# for match in matches:
# # Split by comma and extract all numbers
# numbers = [int(num.strip()) for num in match.split(",")]
# cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
# return list(set(cited_numbers)) # Return unique numbers
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
# citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
# numbers = [int(num.strip()) for num in citation_content.split(",")]
# For multiple citations like [[3, 5, 7]], create separate linked citations
linked_citations = []
for num in numbers:
if num - 1 < len(docs): # Check bounds
link = docs[num - 1].link or ""
linked_citations.append(f"[[{num}]]({link})")
else:
linked_citations.append(f"[[{num}]]") # No link if out of bounds
# # For multiple citations like [[3, 5, 7]], create separate linked citations
# linked_citations = []
# for num in numbers:
# if num - 1 < len(docs): # Check bounds
# link = docs[num - 1].link or ""
# linked_citations.append(f"[[{num}]]({link})")
# else:
# linked_citations.append(f"[[{num}]]") # No link if out of bounds
return "".join(linked_citations)
# return "".join(linked_citations)
def _insert_chat_message_search_doc_pair(
message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
"""
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
# def _insert_chat_message_search_doc_pair(
# message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
# """
# Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
Args:
message_id: The ID of the chat message
search_doc_id: The ID of the search document
db_session: The database session
"""
for search_doc_id in search_doc_ids:
chat_message_search_doc = ChatMessage__SearchDoc(
chat_message_id=message_id, search_doc_id=search_doc_id
)
db_session.add(chat_message_search_doc)
# Args:
# message_id: The ID of the chat message
# search_doc_id: The ID of the search document
# db_session: The database session
# """
# for search_doc_id in search_doc_ids:
# chat_message_search_doc = ChatMessage__SearchDoc(
# chat_message_id=message_id, search_doc_id=search_doc_id
# )
# db_session.add(chat_message_search_doc)
def save_iteration(
state: MainState,
graph_config: GraphConfig,
aggregated_context: AggregatedDRContext,
final_answer: str,
all_cited_documents: list[InferenceSection],
is_internet_marker_dict: dict[str, bool],
num_tokens: int,
) -> None:
db_session = graph_config.persistence.db_session
message_id = graph_config.persistence.message_id
research_type = graph_config.behavior.research_type
db_session = graph_config.persistence.db_session
# def save_iteration(
# state: MainState,
# graph_config: GraphConfig,
# aggregated_context: AggregatedDRContext,
# final_answer: str,
# all_cited_documents: list[InferenceSection],
# is_internet_marker_dict: dict[str, bool],
# num_tokens: int,
# ) -> None:
# db_session = graph_config.persistence.db_session
# message_id = graph_config.persistence.message_id
# first, insert the search_docs
search_docs = [
create_search_doc_from_inference_section(
inference_section=inference_section,
is_internet=is_internet_marker_dict.get(
inference_section.center_chunk.document_id, False
), # TODO: revisit
db_session=db_session,
commit=False,
)
for inference_section in all_cited_documents
]
# # first, insert the search_docs
# search_docs = [
# create_search_doc_from_inference_section(
# inference_section=inference_section,
# is_internet=is_internet_marker_dict.get(
# inference_section.center_chunk.document_id, False
# ), # TODO: revisit
# db_session=db_session,
# commit=False,
# )
# for inference_section in all_cited_documents
# ]
# then, map_search_docs to message
_insert_chat_message_search_doc_pair(
message_id, [search_doc.id for search_doc in search_docs], db_session
)
# # then, map_search_docs to message
# _insert_chat_message_search_doc_pair(
# message_id, [search_doc.id for search_doc in search_docs], db_session
# )
# lastly, insert the citations
citation_dict: dict[int, int] = {}
cited_doc_nrs = _extract_citation_numbers(final_answer)
if search_docs:
for cited_doc_nr in cited_doc_nrs:
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# # lastly, insert the citations
# citation_dict: dict[int, int] = {}
# cited_doc_nrs = _extract_citation_numbers(final_answer)
# if search_docs:
# for cited_doc_nr in cited_doc_nrs:
# citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
# TODO: generate plan as dict in the first place
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# # TODO: generate plan as dict in the first place
# plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
# plan_of_record_dict = parse_plan_to_dict(plan_of_record)
# Update the chat message and its parent message in database
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=final_answer,
citations=citation_dict,
research_type=research_type,
research_plan=plan_of_record_dict,
final_documents=search_docs,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=num_tokens,
)
# # Update the chat message and its parent message in database
# update_db_session_with_messages(
# db_session=db_session,
# chat_message_id=message_id,
# chat_session_id=graph_config.persistence.chat_session_id,
# is_agentic=graph_config.behavior.use_agentic_search,
# message=final_answer,
# citations=citation_dict,
# research_type=None, # research_type is deprecated
# research_plan=plan_of_record_dict,
# final_documents=search_docs,
# update_parent_message=True,
# research_answer_purpose=ResearchAnswerPurpose.ANSWER,
# token_count=num_tokens,
# )
for iteration_preparation in state.iteration_instructions:
research_agent_iteration_step = ResearchAgentIteration(
primary_question_id=message_id,
reasoning=iteration_preparation.reasoning,
purpose=iteration_preparation.purpose,
iteration_nr=iteration_preparation.iteration_nr,
)
db_session.add(research_agent_iteration_step)
# for iteration_preparation in state.iteration_instructions:
# research_agent_iteration_step = ResearchAgentIteration(
# primary_question_id=message_id,
# reasoning=iteration_preparation.reasoning,
# purpose=iteration_preparation.purpose,
# iteration_nr=iteration_preparation.iteration_nr,
# )
# db_session.add(research_agent_iteration_step)
for iteration_answer in aggregated_context.global_iteration_responses:
# for iteration_answer in aggregated_context.global_iteration_responses:
retrieved_search_docs = convert_inference_sections_to_search_docs(
list(iteration_answer.cited_documents.values())
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(
# list(iteration_answer.cited_documents.values())
# )
# Convert SavedSearchDoc objects to JSON-serializable format
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
# # Convert SavedSearchDoc objects to JSON-serializable format
# serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
primary_question_id=message_id,
iteration_nr=iteration_answer.iteration_nr,
iteration_sub_step_nr=iteration_answer.parallelization_nr,
sub_step_instructions=iteration_answer.question,
sub_step_tool_id=iteration_answer.tool_id,
sub_answer=iteration_answer.answer,
reasoning=iteration_answer.reasoning,
claims=iteration_answer.claims,
cited_doc_results=serialized_search_docs,
generated_images=(
GeneratedImageFullResult(images=iteration_answer.generated_images)
if iteration_answer.generated_images
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)
# research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
# primary_question_id=message_id,
# iteration_nr=iteration_answer.iteration_nr,
# iteration_sub_step_nr=iteration_answer.parallelization_nr,
# sub_step_instructions=iteration_answer.question,
# sub_step_tool_id=iteration_answer.tool_id,
# sub_answer=iteration_answer.answer,
# reasoning=iteration_answer.reasoning,
# claims=iteration_answer.claims,
# cited_doc_results=serialized_search_docs,
# generated_images=(
# GeneratedImageFullResult(images=iteration_answer.generated_images)
# if iteration_answer.generated_images
# else None
# ),
# additional_data=iteration_answer.additional_data,
# queries=iteration_answer.queries,
# )
# db_session.add(research_agent_iteration_sub_step)
db_session.commit()
# db_session.commit()
def logging(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to close the DR process and finalize the answer.
"""
# def logging(
# state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to close the DR process and finalize the answer.
# """
node_start_time = datetime.now()
# TODO: generate final answer using all the previous steps
# (right now, answers from each step are concatenated onto each other)
# Also, add missing fields once usage in UI is clear.
# node_start_time = datetime.now()
# # TODO: generate final answer using all the previous steps
# # (right now, answers from each step are concatenated onto each other)
# # Also, add missing fields once usage in UI is clear.
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = state.original_question
if not base_question:
raise ValueError("Question is required for closer")
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = state.original_question
# if not base_question:
# raise ValueError("Question is required for closer")
aggregated_context = aggregate_context(
state.iteration_responses, include_documents=True
)
# aggregated_context = aggregate_context(
# state.iteration_responses, include_documents=True
# )
all_cited_documents = aggregated_context.cited_documents
# all_cited_documents = aggregated_context.cited_documents
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
# is_internet_marker_dict = aggregated_context.is_internet_marker_dict
final_answer = state.final_answer or ""
llm_provider = graph_config.tooling.primary_llm.config.model_provider
llm_model_name = graph_config.tooling.primary_llm.config.model_name
# final_answer = state.final_answer or ""
# llm_provider = graph_config.tooling.primary_llm.config.model_provider
# llm_model_name = graph_config.tooling.primary_llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
# llm_tokenizer = get_tokenizer(
# model_name=llm_model_name,
# provider_type=llm_provider,
# )
# num_tokens = len(llm_tokenizer.encode(final_answer or ""))
write_custom_event(current_step_nr, OverallStop(), writer)
# write_custom_event(current_step_nr, OverallStop(), writer)
# Log the research agent steps
save_iteration(
state,
graph_config,
aggregated_context,
final_answer,
all_cited_documents,
is_internet_marker_dict,
num_tokens,
)
# # Log the research agent steps
# save_iteration(
# state,
# graph_config,
# aggregated_context,
# final_answer,
# all_cited_documents,
# is_internet_marker_dict,
# num_tokens,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="logger",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="main",
# node_name="logger",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,132 +1,131 @@
from collections.abc import Iterator
from typing import cast
# from collections.abc import Iterator
# from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
# PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
class BasicSearchProcessedStreamResults(BaseModel):
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
full_answer: str | None = None
cited_references: list[InferenceSection] = []
retrieved_documents: list[LlmDoc] = []
# class BasicSearchProcessedStreamResults(BaseModel):
# ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
# full_answer: str | None = None
# cited_references: list[InferenceSection] = []
# retrieved_documents: list[LlmDoc] = []
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
ind: int,
search_results: list[LlmDoc] | None = None,
generate_final_answer: bool = False,
chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
tool_call_chunk = AIMessageChunk(content="")
# def process_llm_stream(
# messages: Iterator[BaseMessage],
# should_stream_answer: bool,
# writer: StreamWriter,
# ind: int,
# search_results: list[LlmDoc] | None = None,
# generate_final_answer: bool = False,
# chat_message_id: str | None = None,
# ) -> BasicSearchProcessedStreamResults:
# tool_call_chunk = AIMessageChunk(content="")
if search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=search_results,
doc_id_to_rank_map=map_document_id_order(search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
# if search_results:
# answer_handler: AnswerResponseHandler = CitationResponseHandler(
# context_docs=search_results,
# doc_id_to_rank_map=map_document_id_order(search_results),
# )
# else:
# answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# full_answer = ""
# start_final_answer_streaming_set = False
# # Accumulate citation infos if handler emits them
# collected_citation_infos: list[CitationInfo] = []
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
# # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# # the stream will contain AIMessageChunks with tool call information.
# for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
# answer_piece = message.content
# if not isinstance(answer_piece, str):
# # this is only used for logging, so fine to
# # just add the string representation
# answer_piece = str(answer_piece)
# full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message):
# if isinstance(message, AIMessageChunk) and (
# message.tool_call_chunks or message.tool_calls
# ):
# tool_call_chunk += message # type: ignore
# elif should_stream_answer:
# for response_part in answer_handler.handle_response_part(message):
# only stream out answer parts
if (
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
and generate_final_answer
and response_part.answer_piece
):
if chat_message_id is None:
raise ValueError(
"chat_message_id is required when generating final answer"
)
# # only stream out answer parts
# if (
# isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
# and generate_final_answer
# and response_part.answer_piece
# ):
# if chat_message_id is None:
# raise ValueError(
# "chat_message_id is required when generating final answer"
# )
if not start_final_answer_streaming_set:
# Convert LlmDocs to SavedSearchDocs
saved_search_docs = saved_search_docs_from_llm_docs(
search_results
)
write_custom_event(
ind,
MessageStart(content="", final_documents=saved_search_docs),
writer,
)
start_final_answer_streaming_set = True
# if not start_final_answer_streaming_set:
# # Convert LlmDocs to SavedSearchDocs
# saved_search_docs = saved_search_docs_from_llm_docs(
# search_results
# )
# write_custom_event(
# ind,
# MessageStart(content="", final_documents=saved_search_docs),
# writer,
# )
# start_final_answer_streaming_set = True
write_custom_event(
ind,
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
# write_custom_event(
# ind,
# MessageDelta(content=response_part.answer_piece),
# writer,
# )
# # collect citation info objects
# elif isinstance(response_part, CitationInfo):
# collected_citation_infos.append(response_part)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
write_custom_event(
ind,
SectionEnd(),
writer,
)
# if generate_final_answer and start_final_answer_streaming_set:
# # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
# write_custom_event(
# ind,
# SectionEnd(),
# writer,
# )
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
# # Emit citations section if any were collected
# if collected_citation_infos:
# write_custom_event(ind, CitationStart(), writer)
# write_custom_event(
# ind, CitationDelta(citations=collected_citation_infos), writer
# )
# write_custom_event(ind, SectionEnd(), writer)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
)
# logger.debug(f"Full answer: {full_answer}")
# return BasicSearchProcessedStreamResults(
# ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
# )

View File

@@ -1,82 +1,82 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict
from pydantic import BaseModel
# from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
### States ###
# ### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
# class LoggerUpdate(BaseModel):
# log_messages: Annotated[list[str], add] = []
class OrchestrationUpdate(LoggerUpdate):
tools_used: Annotated[list[str], add] = []
query_list: list[str] = []
iteration_nr: int = 0
current_step_nr: int = 1
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
remaining_time_budget: float = 2.0 # set by default to about 2 searches
num_closer_suggestions: int = 0 # how many times the closer was suggested
gaps: list[str] = (
[]
) # gaps that may be identified by the closer before being able to answer the question.
iteration_instructions: Annotated[list[IterationInstructions], add] = []
# class OrchestrationUpdate(LoggerUpdate):
# tools_used: Annotated[list[str], add] = []
# query_list: list[str] = []
# iteration_nr: int = 0
# current_step_nr: int = 1
# plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
# remaining_time_budget: float = 2.0 # set by default to about 2 searches
# num_closer_suggestions: int = 0 # how many times the closer was suggested
# gaps: list[str] = (
# []
# ) # gaps that may be identified by the closer before being able to answer the question.
# iteration_instructions: Annotated[list[IterationInstructions], add] = []
class OrchestrationSetup(OrchestrationUpdate):
original_question: str | None = None
chat_history_string: str | None = None
clarification: OrchestrationClarificationInfo | None = None
available_tools: dict[str, OrchestratorTool] | None = None
num_closer_suggestions: int = 0 # how many times the closer was suggested
# class OrchestrationSetup(OrchestrationUpdate):
# original_question: str | None = None
# chat_history_string: str | None = None
# clarification: OrchestrationClarificationInfo | None = None
# available_tools: dict[str, OrchestratorTool] | None = None
# num_closer_suggestions: int = 0 # how many times the closer was suggested
active_source_types: list[DocumentSource] | None = None
active_source_types_descriptions: str | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
uploaded_test_context: str | None = None
uploaded_image_context: list[dict[str, Any]] | None = None
# active_source_types: list[DocumentSource] | None = None
# active_source_types_descriptions: str | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
# uploaded_test_context: str | None = None
# uploaded_image_context: list[dict[str, Any]] | None = None
class AnswerUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
# class AnswerUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
class FinalUpdate(LoggerUpdate):
final_answer: str | None = None
all_cited_documents: list[InferenceSection] = []
# class FinalUpdate(LoggerUpdate):
# final_answer: str | None = None
# all_cited_documents: list[InferenceSection] = []
## Graph Input State
class MainInput(CoreState):
pass
# ## Graph Input State
# class MainInput(CoreState):
# pass
## Graph State
class MainState(
# This includes the core state
MainInput,
OrchestrationSetup,
AnswerUpdate,
FinalUpdate,
):
pass
# ## Graph State
# class MainState(
# # This includes the core state
# MainInput,
# OrchestrationSetup,
# AnswerUpdate,
# FinalUpdate,
# ):
# pass
## Graph Output State
class MainOutput(TypedDict):
log_messages: list[str]
final_answer: str | None
all_cited_documents: list[InferenceSection]
# ## Graph Output State
# class MainOutput(TypedDict):
# log_messages: list[str]
# final_answer: str | None
# all_cited_documents: list[InferenceSection]

View File

@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
current_step_nr = state.current_step_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# current_step_nr = state.current_step_nr
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,286 +1,261 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# import re
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
# SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def basic_search(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def basic_search(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
current_step_nr = state.current_step_nr
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# current_step_nr = state.current_step_nr
# assistant_system_prompt = state.assistant_system_prompt
# assistant_task_prompt = state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
research_type = graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
# use_agentic_search = graph_config.behavior.use_agentic_search
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
elif len(state.tools_used) == 0:
raise ValueError("tools_used is empty")
# elif len(state.tools_used) == 0:
# raise ValueError("tools_used is empty")
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# search_tool_info = state.available_tools[state.tools_used[-1]]
# search_tool = cast(SearchTool, search_tool_info.tool_object)
# graph_config.tooling.force_use_tool
# sanity check
if search_tool != graph_config.tooling.search_tool:
raise ValueError("search_tool does not match the configured search tool")
# # sanity check
# if search_tool != graph_config.tooling.search_tool:
# raise ValueError("search_tool does not match the configured search tool")
# rewrite query and identify source types
active_source_types_str = ", ".join(
[source.value for source in state.active_source_types or []]
)
# # rewrite query and identify source types
# active_source_types_str = ", ".join(
# [source.value for source in state.active_source_types or []]
# )
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
active_source_types_str=active_source_types_str,
branch_query=branch_query,
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
)
# base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
# active_source_types_str=active_source_types_str,
# branch_query=branch_query,
# current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
# )
try:
search_processing = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, base_search_processing_prompt
),
schema=BaseSearchProcessingResponse,
timeout_override=TF_DR_TIMEOUT_SHORT,
# max_tokens=100,
)
except Exception as e:
logger.error(f"Could not process query: {e}")
raise e
# try:
# search_processing = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, base_search_processing_prompt
# ),
# schema=BaseSearchProcessingResponse,
# timeout_override=TF_DR_TIMEOUT_SHORT,
# # max_tokens=100,
# )
# except Exception as e:
# logger.error(f"Could not process query: {e}")
# raise e
rewritten_query = search_processing.rewritten_query
# rewritten_query = search_processing.rewritten_query
# give back the query so we can render it in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[rewritten_query],
documents=[],
),
writer,
)
# # give back the query so we can render it in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[rewritten_query],
# documents=[],
# ),
# writer,
# )
implied_start_date = search_processing.time_filter
# implied_start_date = search_processing.time_filter
# Validate time_filter format if it exists
implied_time_filter = None
if implied_start_date:
# # Validate time_filter format if it exists
# implied_time_filter = None
# if implied_start_date:
# Check if time_filter is in YYYY-MM-DD format
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
if re.match(date_pattern, implied_start_date):
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
# # Check if time_filter is in YYYY-MM-DD format
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
# if re.match(date_pattern, implied_start_date):
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
specified_source_types: list[DocumentSource] | None = (
strings_to_document_sources(search_processing.specified_source_types)
if search_processing.specified_source_types
else None
)
# specified_source_types: list[DocumentSource] | None = (
# strings_to_document_sources(search_processing.specified_source_types)
# if search_processing.specified_source_types
# else None
# )
if specified_source_types is not None and len(specified_source_types) == 0:
specified_source_types = None
# if specified_source_types is not None and len(specified_source_types) == 0:
# specified_source_types = None
logger.debug(
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
# retrieved_docs: list[InferenceSection] = []
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# for tool_response in search_tool.run(
# query=rewritten_query,
# document_sources=specified_source_types,
# time_filter=implied_time_filter,
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
query=rewritten_query,
document_sources=specified_source_types,
time_filter=implied_time_filter,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=search_db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
# break
break
# # render the retrieved docs in the UI
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=[],
# documents=convert_inference_sections_to_search_docs(
# retrieved_docs, is_internet=False
# ),
# ),
# writer,
# )
# render the retrieved docs in the UI
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=[],
documents=convert_inference_sections_to_search_docs(
retrieved_docs, is_internet=False
),
),
writer,
)
# document_texts_list = []
document_texts_list = []
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
# document_texts_list.append(chunk_text)
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
document_texts_list.append(chunk_text)
# document_texts = "\n\n".join(document_texts_list)
document_texts = "\n\n".join(document_texts_list)
# logger.debug(
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # Built prompt
# Built prompt
# if use_agentic_search:
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
# search_query=branch_query,
# base_question=base_question,
# document_text=document_texts,
# )
if research_type == ResearchType.DEEP:
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
search_query=branch_query,
base_question=base_question,
document_text=document_texts,
)
# # Run LLM
# Run LLM
# # search_answer_json = None
# search_answer_json = invoke_llm_json(
# llm=graph_config.tooling.primary_llm,
# prompt=create_question_prompt(
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
# ),
# schema=SearchAnswer,
# timeout_override=TF_DR_TIMEOUT_LONG,
# # max_tokens=1500,
# )
# search_answer_json = None
search_answer_json = invoke_llm_json(
llm=graph_config.tooling.primary_llm,
prompt=create_question_prompt(
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
),
schema=SearchAnswer,
timeout_override=TF_DR_TIMEOUT_LONG,
# max_tokens=1500,
)
# logger.debug(
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# # get cited documents
# answer_string = search_answer_json.answer
# claims = search_answer_json.claims or []
# reasoning = search_answer_json.reasoning
# # answer_string = ""
# # claims = []
# get cited documents
answer_string = search_answer_json.answer
claims = search_answer_json.claims or []
reasoning = search_answer_json.reasoning
# answer_string = ""
# claims = []
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# if citation_numbers and (
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
# ):
# raise ValueError("Citation numbers are out of range for retrieved docs.")
if citation_numbers and (
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
):
raise ValueError("Citation numbers are out of range for retrieved docs.")
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# }
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
}
# else:
# answer_string = ""
# claims = []
# cited_documents = {
# doc_num + 1: retrieved_doc
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
# }
# reasoning = ""
else:
answer_string = ""
claims = []
cited_documents = {
doc_num + 1: retrieved_doc
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
}
reasoning = ""
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=search_tool_info.llm_path,
tool_id=search_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=reasoning,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=search_tool_info.llm_path,
# tool_id=search_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=reasoning,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,77 +1,77 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
[update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs
]
# # Convert InferenceSections to SavedSearchDocs
# search_docs = SearchDoc.from_chunks_or_sections(doc_list)
# retrieved_saved_search_docs = [
# SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
# for search_doc in search_docs
# ]
for retrieved_saved_search_doc in retrieved_saved_search_docs:
retrieved_saved_search_doc.is_internet = False
# for retrieved_saved_search_doc in retrieved_saved_search_docs:
# retrieved_saved_search_doc.is_internet = False
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="basic_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="basic_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
# basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
# basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_basic_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Web Search Sub-Agent
"""
# def dr_basic_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Web Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", basic_search_branch)
# graph.add_node("branch", basic_search_branch)
graph.add_node("act", basic_search)
# graph.add_node("act", basic_search)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,30 +1,30 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
current_step_nr=state.current_step_nr,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# current_step_nr=state.current_step_nr,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,169 +1,164 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
custom_tool_info = state.available_tools[state.tools_used[-1]]
custom_tool_name = custom_tool_info.name
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
# custom_tool_info = state.available_tools[state.tools_used[-1]]
# custom_tool_name = custom_tool_info.name
# custom_tool = cast(CustomTool, custom_tool_info.tool_object)
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=custom_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[custom_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_LONG,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=custom_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[custom_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_LONG,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# response_summary: CustomToolCallSummary | None = None
# for tool_response in custom_tool.run(**tool_args):
# if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
# response_summary = cast(CustomToolCallSummary, tool_response.response)
# break
# run the tool
response_summary: CustomToolCallSummary | None = None
for tool_response in custom_tool.run(**tool_args):
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
response_summary = cast(CustomToolCallSummary, tool_response.response)
break
# if not response_summary:
# raise ValueError("Custom tool did not return a valid response summary")
if not response_summary:
raise ValueError("Custom tool did not return a valid response summary")
# # summarise tool result
# if not response_summary.response_type:
# raise ValueError("Response type is not returned.")
# summarise tool result
if not response_summary.response_type:
raise ValueError("Response type is not returned.")
# if response_summary.response_type == "json":
# tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
# elif response_summary.response_type in {"image", "csv"}:
# tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
# else:
# tool_result_str = str(response_summary.tool_result)
if response_summary.response_type == "json":
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
elif response_summary.response_type in {"image", "csv"}:
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
else:
tool_result_str = str(response_summary.tool_result)
# tool_str = (
# f"Tool used: {custom_tool_name}\n"
# f"Description: {custom_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {custom_tool_name}\n"
f"Description: {custom_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
# get file_ids:
file_ids = None
if response_summary.response_type in {"image", "csv"} and hasattr(
response_summary.tool_result, "file_ids"
):
file_ids = response_summary.tool_result.file_ids
# logger.debug(
# f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
logger.debug(
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=custom_tool_name,
tool_id=custom_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type=response_summary.response_type,
data=response_summary.tool_result,
file_ids=file_ids,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=custom_tool_name,
# tool_id=custom_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type=response_summary.response_type,
# data=response_summary.tool_result,
# file_ids=file_ids,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def custom_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def custom_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=new_update.file_ids,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=new_update.file_ids,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
# custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
# custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
# custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_custom_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# def dr_custom_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", custom_tool_branch)
# graph.add_node("branch", custom_tool_branch)
graph.add_node("act", custom_tool_act)
# graph.add_node("act", custom_tool_act)
graph.add_node("reducer", custom_tool_reducer)
# graph.add_node("reducer", custom_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="generic_internal_tool",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="generic_internal_tool",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,149 +1,147 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_act(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_act(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
generic_internal_tool_name = generic_internal_tool_info.llm_path
generic_internal_tool = generic_internal_tool_info.tool_object
# generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
# generic_internal_tool_name = generic_internal_tool_info.llm_path
# generic_internal_tool = generic_internal_tool_info.tool_object
if generic_internal_tool is None:
raise ValueError("generic_internal_tool is not set")
# if generic_internal_tool is None:
# raise ValueError("generic_internal_tool is not set")
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
base_question = graph_config.inputs.prompt_builder.raw_user_query
# graph_config = cast(GraphConfig, config["metadata"]["config"])
# base_question = graph_config.inputs.prompt_builder.raw_user_query
logger.debug(
f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# get tool call args
tool_args: dict | None = None
if graph_config.tooling.using_tool_calling_llm:
# get tool call args from tool-calling LLM
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
query=branch_query,
base_question=base_question,
tool_description=generic_internal_tool_info.description,
)
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
tool_use_prompt,
tools=[generic_internal_tool.tool_definition()],
tool_choice="required",
timeout_override=TF_DR_TIMEOUT_SHORT,
)
# # get tool call args
# tool_args: dict | None = None
# if graph_config.tooling.using_tool_calling_llm:
# # get tool call args from tool-calling LLM
# tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
# query=branch_query,
# base_question=base_question,
# tool_description=generic_internal_tool_info.description,
# )
# tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
# tool_use_prompt,
# tools=[generic_internal_tool.tool_definition()],
# tool_choice="required",
# timeout_override=TF_DR_TIMEOUT_SHORT,
# )
# make sure we got a tool call
if (
isinstance(tool_calling_msg, AIMessage)
and len(tool_calling_msg.tool_calls) == 1
):
tool_args = tool_calling_msg.tool_calls[0]["args"]
else:
logger.warning("Tool-calling LLM did not emit a tool call")
# # make sure we got a tool call
# if (
# isinstance(tool_calling_msg, AIMessage)
# and len(tool_calling_msg.tool_calls) == 1
# ):
# tool_args = tool_calling_msg.tool_calls[0]["args"]
# else:
# logger.warning("Tool-calling LLM did not emit a tool call")
if tool_args is None:
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
query=branch_query,
history=[],
llm=graph_config.tooling.primary_llm,
force_run=True,
)
# if tool_args is None:
# raise ValueError(
# "Failed to obtain tool arguments from LLM - tool calling is required"
# )
if tool_args is None:
raise ValueError("Failed to obtain tool arguments from LLM")
# # run the tool
# tool_responses = list(generic_internal_tool.run(**tool_args))
# final_data = generic_internal_tool.final_result(*tool_responses)
# tool_result_str = json.dumps(final_data, ensure_ascii=False)
# run the tool
tool_responses = list(generic_internal_tool.run(**tool_args))
final_data = generic_internal_tool.final_result(*tool_responses)
tool_result_str = json.dumps(final_data, ensure_ascii=False)
# tool_str = (
# f"Tool used: {generic_internal_tool.display_name}\n"
# f"Description: {generic_internal_tool_info.description}\n"
# f"Result: {tool_result_str}"
# )
tool_str = (
f"Tool used: {generic_internal_tool.display_name}\n"
f"Description: {generic_internal_tool_info.description}\n"
f"Result: {tool_result_str}"
)
# if generic_internal_tool.display_name == "Okta Profile":
# tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
# else:
# tool_prompt = CUSTOM_TOOL_USE_PROMPT
if generic_internal_tool.display_name == "Okta Profile":
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
else:
tool_prompt = CUSTOM_TOOL_USE_PROMPT
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
tool_summary_prompt = tool_prompt.build(
query=branch_query, base_question=base_question, tool_response=tool_str
)
answer_string = str(
graph_config.tooling.primary_llm.invoke_langchain(
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
).content
).strip()
# tool_summary_prompt = tool_prompt.build(
# query=branch_query, base_question=base_question, tool_response=tool_str
# )
# answer_string = str(
# graph_config.tooling.primary_llm.invoke_langchain(
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
# ).content
# ).strip()
logger.debug(
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=generic_internal_tool.llm_name,
tool_id=generic_internal_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning="",
additional_data=None,
response_type="text", # TODO: convert all response types to enums
data=answer_string,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="tool_calling",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=generic_internal_tool.name,
# tool_id=generic_internal_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning="",
# additional_data=None,
# response_type="text", # TODO: convert all response types to enums
# data=answer_string,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="tool_calling",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def generic_internal_tool_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a generic tool call as part of the DR process.
"""
# def generic_internal_tool_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a generic tool call as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
current_step_nr = state.current_step_nr
# current_step_nr = state.current_step_nr
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
for new_update in new_updates:
# for new_update in new_updates:
if not new_update.response_type:
raise ValueError("Response type is not returned.")
# if not new_update.response_type:
# raise ValueError("Response type is not returned.")
write_custom_event(
current_step_nr,
CustomToolStart(
tool_name=new_update.tool,
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolStart(
# tool_name=new_update.tool,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
CustomToolDelta(
tool_name=new_update.tool,
response_type=new_update.response_type,
data=new_update.data,
file_ids=[],
),
writer,
)
# write_custom_event(
# current_step_nr,
# CustomToolDelta(
# tool_name=new_update.tool,
# response_type=new_update.response_type,
# data=new_update.data,
# file_ids=[],
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
log_messages=[
get_langgraph_node_log_string(
graph_component="custom_tool",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="custom_tool",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
# SubAgentInput,
# )
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel call for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel call for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
# generic_internal_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
# generic_internal_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
# generic_internal_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_generic_internal_tool_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Generic Tool Sub-Agent
"""
# def dr_generic_internal_tool_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Generic Tool Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", generic_internal_tool_branch)
# graph.add_node("branch", generic_internal_tool_branch)
graph.add_node("act", generic_internal_tool_act)
# graph.add_node("act", generic_internal_tool_act)
graph.add_node("reducer", generic_internal_tool_reducer)
# graph.add_node("reducer", generic_internal_tool_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,45 +1,45 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a image generation as part of the DR process.
"""
# def image_generation_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a image generation as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
# tell frontend that we are starting the image generation tool
write_custom_event(
state.current_step_nr,
ImageGenerationToolStart(),
writer,
)
# # tell frontend that we are starting the image generation tool
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolStart(),
# writer,
# )
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,189 +1,187 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.file_store.utils import build_frontend_file_url
# from onyx.file_store.utils import save_files
# from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_HEARTBEAT_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# IMAGE_GENERATION_RESPONSE_ID,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationResponse,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import (
# ImageGenerationTool,
# )
# from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def image_generation(
state: BranchInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def image_generation(
# state: BranchInput,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
parallelization_nr = state.parallelization_nr
state.assistant_system_prompt
state.assistant_task_prompt
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# parallelization_nr = state.parallelization_nr
# state.assistant_system_prompt
# state.assistant_task_prompt
branch_query = state.branch_question
if not branch_query:
raise ValueError("branch_query is not set")
# branch_query = state.branch_question
# if not branch_query:
# raise ValueError("branch_query is not set")
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.prompt_builder.raw_user_query
graph_config.behavior.research_type
# graph_config = cast(GraphConfig, config["metadata"]["config"])
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
# image_tool_info = state.available_tools[state.tools_used[-1]]
# image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
# image_prompt = branch_query
# requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
# try:
# parsed_query = json.loads(branch_query)
# except json.JSONDecodeError:
# parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
# if isinstance(parsed_query, dict):
# prompt_from_llm = parsed_query.get("prompt")
# if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
# image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
# raw_shape = parsed_query.get("shape")
# if isinstance(raw_shape, str):
# try:
# requested_shape = ImageShape(raw_shape)
# except ValueError:
# logger.warning(
# "Received unsupported image shape '%s' from LLM. Falling back to square.",
# raw_shape,
# )
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
# # Generate images using the image generation tool
# image_generation_responses: list[ImageGenerationResponse] = []
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
# if requested_shape is not None:
# tool_iterator = image_tool.run(
# prompt=image_prompt,
# shape=requested_shape.value,
# )
# else:
# tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
state.current_step_nr,
ImageGenerationToolHeartbeat(),
writer,
)
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
response = cast(list[ImageGenerationResponse], tool_response.response)
image_generation_responses = response
break
# for tool_response in tool_iterator:
# if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# # Stream heartbeat to frontend
# write_custom_event(
# state.current_step_nr,
# ImageGenerationToolHeartbeat(),
# writer,
# )
# elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
# response = cast(list[ImageGenerationResponse], tool_response.response)
# image_generation_responses = response
# break
# save images to file store
file_ids = save_files(
urls=[],
base64_files=[img.image_data for img in image_generation_responses],
)
# # save images to file store
# file_ids = save_files(
# urls=[],
# base64_files=[img.image_data for img in image_generation_responses],
# )
final_generated_images = [
GeneratedImage(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
# final_generated_images = [
# GeneratedImage(
# file_id=file_id,
# url=build_frontend_file_url(file_id),
# revised_prompt=img.revised_prompt,
# shape=(requested_shape or ImageShape.SQUARE).value,
# )
# for file_id, img in zip(file_ids, image_generation_responses)
# ]
logger.debug(
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
# Create answer string describing the generated images
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
# # Create answer string describing the generated images
# if final_generated_images:
# image_descriptions = []
# for i, img in enumerate(final_generated_images, 1):
# if img.shape and img.shape != ImageShape.SQUARE.value:
# image_descriptions.append(
# f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
# )
# else:
# image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
# answer_string = (
# f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
# + "\n".join(image_descriptions)
# )
# if requested_shape:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
# )
# else:
# reasoning = (
# "Used image generation tool to create "
# f"{len(final_generated_images)} image(s) based on the user's request."
# )
# else:
# answer_string = f"Failed to generate images for request: {image_prompt}"
# reasoning = "Image generation tool did not return any results."
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=image_tool_info.llm_path,
tool_id=image_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=branch_query,
answer=answer_string,
claims=[],
cited_documents={},
reasoning=reasoning,
generated_images=final_generated_images,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="generating",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=image_tool_info.llm_path,
# tool_id=image_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=branch_query,
# answer=answer_string,
# claims=[],
# cited_documents={},
# reasoning=reasoning,
# generated_images=final_generated_images,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="generating",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,71 +1,71 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import GeneratedImage
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def is_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a standard search as part of the DR process.
"""
# def is_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a standard search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
generated_images: list[GeneratedImage] = []
for update in new_updates:
if update.generated_images:
generated_images.extend(update.generated_images)
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
# generated_images: list[GeneratedImage] = []
# for update in new_updates:
# if update.generated_images:
# generated_images.extend(update.generated_images)
# Write the results to the stream
write_custom_event(
current_step_nr,
ImageGenerationToolDelta(
images=generated_images,
),
writer,
)
# # Write the results to the stream
# write_custom_event(
# current_step_nr,
# ImageGenerationFinal(
# images=generated_images,
# ),
# writer,
# )
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="image_generation",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="image_generation",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,29 +1,29 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
active_source_types=state.active_source_types,
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:MAX_DR_PARALLEL_SEARCH]
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# active_source_types=state.active_source_types,
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:MAX_DR_PARALLEL_SEARCH]
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
image_generation_branch,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
image_generation,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
# image_generation_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
# image_generation,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
# is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_image_generation_graph_builder() -> StateGraph:
"""
LangGraph graph builder for Image Generation Sub-Agent
"""
# def dr_image_generation_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for Image Generation Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", image_generation_branch)
# graph.add_node("branch", image_generation_branch)
graph.add_node("act", image_generation)
# graph.add_node("act", image_generation)
graph.add_node("reducer", is_reducer)
# graph.add_node("reducer", is_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,13 +1,13 @@
from pydantic import BaseModel
# from pydantic import BaseModel
class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# class GeneratedImage(BaseModel):
# file_id: str
# url: str
# revised_prompt: str
# shape: str | None = None
# Needed for PydanticType
class GeneratedImageFullResult(BaseModel):
images: list[GeneratedImage]
# # Needed for PydanticType
# class GeneratedImageFullResult(BaseModel):
# images: list[GeneratedImage]

View File

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search_branch(
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search_branch(
# state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
# logger.debug(f"Search start for KG Search {iteration_nr} at {datetime.now()}")
return LoggerUpdate(
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="branching",
node_start_time=node_start_time,
)
],
)
# return LoggerUpdate(
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="branching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,97 +1,97 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
# from onyx.agents.agent_search.kb_search.states import MainInput as KbMainInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def kg_search(
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search(
# state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> BranchUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
iteration_nr = state.iteration_nr
state.current_step_nr
parallelization_nr = state.parallelization_nr
# node_start_time = datetime.now()
# iteration_nr = state.iteration_nr
# state.current_step_nr
# parallelization_nr = state.parallelization_nr
search_query = state.branch_question
if not search_query:
raise ValueError("search_query is not set")
# search_query = state.branch_question
# if not search_query:
# raise ValueError("search_query is not set")
logger.debug(
f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
# logger.debug(
# f"Search start for KG Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
# )
if not state.available_tools:
raise ValueError("available_tools is not set")
# if not state.available_tools:
# raise ValueError("available_tools is not set")
kg_tool_info = state.available_tools[state.tools_used[-1]]
# kg_tool_info = state.available_tools[state.tools_used[-1]]
kb_graph = kb_graph_builder().compile()
# kb_graph = kb_graph_builder().compile()
kb_results = kb_graph.invoke(
input=KbMainInput(question=search_query, individual_flow=False),
config=config,
)
# kb_results = kb_graph.invoke(
# input=KbMainInput(question=search_query, individual_flow=False),
# config=config,
# )
# get cited documents
answer_string = kb_results.get("final_answer") or "No answer provided"
claims: list[str] = []
retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
# # get cited documents
# answer_string = kb_results.get("final_answer") or "No answer provided"
# claims: list[str] = []
# retrieved_docs: list[InferenceSection] = kb_results.get("retrieved_documents", [])
(
citation_numbers,
answer_string,
claims,
) = extract_document_citations(answer_string, claims)
# (
# citation_numbers,
# answer_string,
# claims,
# ) = extract_document_citations(answer_string, claims)
# if citation is empty, the answer must have come from the KG rather than a doc
# in that case, simply cite the docs returned by the KG
if not citation_numbers:
citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
# # if citation is empty, the answer must have come from the KG rather than a doc
# # in that case, simply cite the docs returned by the KG
# if not citation_numbers:
# citation_numbers = [i + 1 for i in range(len(retrieved_docs))]
cited_documents = {
citation_number: retrieved_docs[citation_number - 1]
for citation_number in citation_numbers
if citation_number <= len(retrieved_docs)
}
# cited_documents = {
# citation_number: retrieved_docs[citation_number - 1]
# for citation_number in citation_numbers
# if citation_number <= len(retrieved_docs)
# }
return BranchUpdate(
branch_iteration_responses=[
IterationAnswer(
tool=kg_tool_info.llm_path,
tool_id=kg_tool_info.tool_id,
iteration_nr=iteration_nr,
parallelization_nr=parallelization_nr,
question=search_query,
answer=answer_string,
claims=claims,
cited_documents=cited_documents,
reasoning=None,
additional_data=None,
)
],
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="searching",
node_start_time=node_start_time,
)
],
)
# return BranchUpdate(
# branch_iteration_responses=[
# IterationAnswer(
# tool=kg_tool_info.llm_path,
# tool_id=kg_tool_info.tool_id,
# iteration_nr=iteration_nr,
# parallelization_nr=parallelization_nr,
# question=search_query,
# answer=answer_string,
# claims=claims,
# cited_documents=cited_documents,
# reasoning=None,
# additional_data=None,
# )
# ],
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="searching",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,121 +1,121 @@
from datetime import datetime
# from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.shared_graph_utils.utils import (
# get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import ReasoningDelta
# from onyx.server.query_and_chat.streaming_models import ReasoningStart
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
_MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
# _MAX_KG_STEAMED_ANSWER_LENGTH = 1000 # num characters
def kg_search_reducer(
state: SubAgentMainState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
"""
LangGraph node to perform a KG search as part of the DR process.
"""
# def kg_search_reducer(
# state: SubAgentMainState,
# config: RunnableConfig,
# writer: StreamWriter = lambda _: None,
# ) -> SubAgentUpdate:
# """
# LangGraph node to perform a KG search as part of the DR process.
# """
node_start_time = datetime.now()
# node_start_time = datetime.now()
branch_updates = state.branch_iteration_responses
current_iteration = state.iteration_nr
current_step_nr = state.current_step_nr
# branch_updates = state.branch_iteration_responses
# current_iteration = state.iteration_nr
# current_step_nr = state.current_step_nr
new_updates = [
update for update in branch_updates if update.iteration_nr == current_iteration
]
# new_updates = [
# update for update in branch_updates if update.iteration_nr == current_iteration
# ]
queries = [update.question for update in new_updates]
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
# queries = [update.question for update in new_updates]
# doc_lists = [list(update.cited_documents.values()) for update in new_updates]
doc_list = []
# doc_list = []
for xs in doc_lists:
for x in xs:
doc_list.append(x)
# for xs in doc_lists:
# for x in xs:
# doc_list.append(x)
retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
kg_answer = (
"The Knowledge Graph Answer:\n\n" + new_updates[0].answer
if len(queries) == 1
else None
)
# retrieved_search_docs = convert_inference_sections_to_search_docs(doc_list)
# kg_answer = (
# "The Knowledge Graph Answer:\n\n" + new_updates[0].answer
# if len(queries) == 1
# else None
# )
if len(retrieved_search_docs) > 0:
write_custom_event(
current_step_nr,
SearchToolStart(
is_internet_search=False,
),
writer,
)
write_custom_event(
current_step_nr,
SearchToolDelta(
queries=queries,
documents=retrieved_search_docs,
),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# if len(retrieved_search_docs) > 0:
# write_custom_event(
# current_step_nr,
# SearchToolStart(
# is_internet_search=False,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SearchToolDelta(
# queries=queries,
# documents=retrieved_search_docs,
# ),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
if kg_answer is not None:
# if kg_answer is not None:
kg_display_answer = (
f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
else kg_answer
)
# kg_display_answer = (
# f"{kg_answer[:_MAX_KG_STEAMED_ANSWER_LENGTH]}..."
# if len(kg_answer) > _MAX_KG_STEAMED_ANSWER_LENGTH
# else kg_answer
# )
write_custom_event(
current_step_nr,
ReasoningStart(),
writer,
)
write_custom_event(
current_step_nr,
ReasoningDelta(reasoning=kg_display_answer),
writer,
)
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
# write_custom_event(
# current_step_nr,
# ReasoningStart(),
# writer,
# )
# write_custom_event(
# current_step_nr,
# ReasoningDelta(reasoning=kg_display_answer),
# writer,
# )
# write_custom_event(
# current_step_nr,
# SectionEnd(),
# writer,
# )
current_step_nr += 1
# current_step_nr += 1
return SubAgentUpdate(
iteration_responses=new_updates,
current_step_nr=current_step_nr,
log_messages=[
get_langgraph_node_log_string(
graph_component="kg_search",
node_name="consolidation",
node_start_time=node_start_time,
)
],
)
# return SubAgentUpdate(
# iteration_responses=new_updates,
# current_step_nr=current_step_nr,
# log_messages=[
# get_langgraph_node_log_string(
# graph_component="kg_search",
# node_name="consolidation",
# node_start_time=node_start_time,
# )
# ],
# )

View File

@@ -1,27 +1,27 @@
from collections.abc import Hashable
# from collections.abc import Hashable
from langgraph.types import Send
# from langgraph.types import Send
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
return [
Send(
"act",
BranchInput(
iteration_nr=state.iteration_nr,
parallelization_nr=parallelization_nr,
branch_question=query,
context="",
tools_used=state.tools_used,
available_tools=state.available_tools,
assistant_system_prompt=state.assistant_system_prompt,
assistant_task_prompt=state.assistant_task_prompt,
),
)
for parallelization_nr, query in enumerate(
state.query_list[:1] # no parallel search for now
)
]
# def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
# return [
# Send(
# "act",
# BranchInput(
# iteration_nr=state.iteration_nr,
# parallelization_nr=parallelization_nr,
# branch_question=query,
# context="",
# tools_used=state.tools_used,
# available_tools=state.available_tools,
# assistant_system_prompt=state.assistant_system_prompt,
# assistant_task_prompt=state.assistant_task_prompt,
# ),
# )
# for parallelization_nr, query in enumerate(
# state.query_list[:1] # no parallel search for now
# )
# ]

View File

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
kg_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
kg_search,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
kg_search_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_1_branch import (
# kg_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_2_act import (
# kg_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_3_reduce import (
# kg_search_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_conditional_edges import (
# branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger
logger = setup_logger()
# logger = setup_logger()
def dr_kg_search_graph_builder() -> StateGraph:
"""
LangGraph graph builder for KG Search Sub-Agent
"""
# def dr_kg_search_graph_builder() -> StateGraph:
# """
# LangGraph graph builder for KG Search Sub-Agent
# """
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
# graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
### Add nodes ###
# ### Add nodes ###
graph.add_node("branch", kg_search_branch)
# graph.add_node("branch", kg_search_branch)
graph.add_node("act", kg_search)
# graph.add_node("act", kg_search)
graph.add_node("reducer", kg_search_reducer)
# graph.add_node("reducer", kg_search_reducer)
### Add edges ###
# ### Add edges ###
graph.add_edge(start_key=START, end_key="branch")
# graph.add_edge(start_key=START, end_key="branch")
graph.add_conditional_edges("branch", branching_router)
# graph.add_conditional_edges("branch", branching_router)
graph.add_edge(start_key="act", end_key="reducer")
# graph.add_edge(start_key="act", end_key="reducer")
graph.add_edge(start_key="reducer", end_key=END)
# graph.add_edge(start_key="reducer", end_key=END)
return graph
# return graph

View File

@@ -1,46 +1,46 @@
from operator import add
from typing import Annotated
# from operator import add
# from typing import Annotated
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.db.connector import DocumentSource
class SubAgentUpdate(LoggerUpdate):
iteration_responses: Annotated[list[IterationAnswer], add] = []
current_step_nr: int = 1
# class SubAgentUpdate(LoggerUpdate):
# iteration_responses: Annotated[list[IterationAnswer], add] = []
# current_step_nr: int = 1
class BranchUpdate(LoggerUpdate):
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
# class BranchUpdate(LoggerUpdate):
# branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
class SubAgentInput(LoggerUpdate):
iteration_nr: int = 0
current_step_nr: int = 1
query_list: list[str] = []
context: str | None = None
active_source_types: list[DocumentSource] | None = None
tools_used: Annotated[list[str], add] = []
available_tools: dict[str, OrchestratorTool] | None = None
assistant_system_prompt: str | None = None
assistant_task_prompt: str | None = None
# class SubAgentInput(LoggerUpdate):
# iteration_nr: int = 0
# current_step_nr: int = 1
# query_list: list[str] = []
# context: str | None = None
# active_source_types: list[DocumentSource] | None = None
# tools_used: Annotated[list[str], add] = []
# available_tools: dict[str, OrchestratorTool] | None = None
# assistant_system_prompt: str | None = None
# assistant_task_prompt: str | None = None
class SubAgentMainState(
# This includes the core state
SubAgentInput,
SubAgentUpdate,
BranchUpdate,
):
pass
# class SubAgentMainState(
# # This includes the core state
# SubAgentInput,
# SubAgentUpdate,
# BranchUpdate,
# ):
# pass
class BranchInput(SubAgentInput):
parallelization_nr: int = 0
branch_question: str
# class BranchInput(SubAgentInput):
# parallelization_nr: int = 0
# branch_question: str
class CustomToolBranchInput(LoggerUpdate):
tool_info: OrchestratorTool
# class CustomToolBranchInput(LoggerUpdate):
# tool_info: OrchestratorTool

View File

@@ -1,71 +1,71 @@
from collections.abc import Sequence
# from collections.abc import Sequence
from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
# from exa_py import Exa
# from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchResult,
# )
# from onyx.configs.chat_configs import EXA_API_KEY
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)
# class ExaClient(WebSearchProvider):
# def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
# self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
),
num_results=10,
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# response = self.exa.search_and_contents(
# query,
# type="auto",
# highlights=HighlightsContentsOptions(
# num_sentences=2,
# highlights_per_url=1,
# ),
# num_results=10,
# )
return [
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
author=result.author,
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebSearchResult(
# title=result.title or "",
# link=result.url,
# snippet=result.highlights[0] if result.highlights else "",
# author=result.author,
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: Sequence[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=list(urls),
text=True,
livecrawl="preferred",
)
# @retry_builder(tries=3, delay=1, backoff=2)
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# response = self.exa.get_contents(
# urls=list(urls),
# text=True,
# livecrawl="preferred",
# )
return [
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",
published_date=(
time_str_to_utc(result.published_date)
if result.published_date
else None
),
)
for result in response.results
]
# return [
# WebContent(
# title=result.title or "",
# link=result.url,
# full_content=result.text or "",
# published_date=(
# time_str_to_utc(result.published_date)
# if result.published_date
# else None
# ),
# )
# for result in response.results
# ]

View File

@@ -0,0 +1,163 @@
# from __future__ import annotations
# from collections.abc import Sequence
# from concurrent.futures import ThreadPoolExecutor
# from dataclasses import dataclass
# from datetime import datetime
# from typing import Any
# import requests
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContentProvider,
# )
# from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
# from onyx.utils.logger import setup_logger
# from onyx.utils.retry_wrapper import retry_builder
# logger = setup_logger()
# FIRECRAWL_SCRAPE_URL = "https://api.firecrawl.dev/v1/scrape"
# _DEFAULT_MAX_WORKERS = 4
# @dataclass
# class ExtractedContentFields:
# text: str
# title: str
# published_date: datetime | None
# class FirecrawlClient(WebContentProvider):
# def __init__(
# self,
# api_key: str,
# *,
# base_url: str = FIRECRAWL_SCRAPE_URL,
# timeout_seconds: int = 30,
# ) -> None:
# self._headers = {
# "Authorization": f"Bearer {api_key}",
# "Content-Type": "application/json",
# }
# self._base_url = base_url
# self._timeout_seconds = timeout_seconds
# self._last_error: str | None = None
# @property
# def last_error(self) -> str | None:
# return self._last_error
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# if not urls:
# return []
# max_workers = min(_DEFAULT_MAX_WORKERS, len(urls))
# with ThreadPoolExecutor(max_workers=max_workers) as executor:
# return list(executor.map(self._get_webpage_content_safe, urls))
# def _get_webpage_content_safe(self, url: str) -> WebContent:
# try:
# return self._get_webpage_content(url)
# except Exception as exc:
# self._last_error = str(exc)
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# @retry_builder(tries=3, delay=1, backoff=2)
# def _get_webpage_content(self, url: str) -> WebContent:
# payload = {
# "url": url,
# "formats": ["markdown"],
# }
# response = requests.post(
# self._base_url,
# headers=self._headers,
# json=payload,
# timeout=self._timeout_seconds,
# )
# if response.status_code != 200:
# try:
# error_payload = response.json()
# except Exception:
# error_payload = response.text
# self._last_error = (
# error_payload if isinstance(error_payload, str) else str(error_payload)
# )
# if 400 <= response.status_code < 500:
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# raise ValueError(
# f"Firecrawl fetch failed with status {response.status_code}."
# )
# else:
# self._last_error = None
# response_json = response.json()
# extracted = self._extract_content_fields(response_json, url)
# return WebContent(
# title=extracted.title,
# link=url,
# full_content=extracted.text,
# published_date=extracted.published_date,
# scrape_successful=bool(extracted.text),
# )
# @staticmethod
# def _extract_content_fields(
# response_json: dict[str, Any], url: str
# ) -> ExtractedContentFields:
# data_section = response_json.get("data") or {}
# metadata = data_section.get("metadata") or response_json.get("metadata") or {}
# text_candidates = [
# data_section.get("markdown"),
# data_section.get("content"),
# data_section.get("text"),
# response_json.get("markdown"),
# response_json.get("content"),
# response_json.get("text"),
# ]
# text = next((candidate for candidate in text_candidates if candidate), "")
# title = metadata.get("title") or response_json.get("title") or ""
# published_date = None
# published_date_str = (
# metadata.get("publishedTime")
# or metadata.get("date")
# or response_json.get("publishedTime")
# or response_json.get("date")
# )
# if published_date_str:
# try:
# published_date = time_str_to_utc(published_date_str)
# except Exception:
# published_date = None
# if not text:
# logger.warning(f"Firecrawl returned empty content for url={url}")
# return ExtractedContentFields(
# text=text or "",
# title=title or "",
# published_date=published_date,
# )

View File

@@ -0,0 +1,138 @@
# from __future__ import annotations
# from collections.abc import Sequence
# from datetime import datetime
# from typing import Any
# import requests
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebSearchProvider,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
# from onyx.utils.logger import setup_logger
# from onyx.utils.retry_wrapper import retry_builder
# logger = setup_logger()
# GOOGLE_CUSTOM_SEARCH_URL = "https://customsearch.googleapis.com/customsearch/v1"
# class GooglePSEClient(WebSearchProvider):
# def __init__(
# self,
# api_key: str,
# search_engine_id: str,
# *,
# num_results: int = 10,
# timeout_seconds: int = 10,
# ) -> None:
# self._api_key = api_key
# self._search_engine_id = search_engine_id
# self._num_results = num_results
# self._timeout_seconds = timeout_seconds
# @retry_builder(tries=3, delay=1, backoff=2)
# def search(self, query: str) -> list[WebSearchResult]:
# params: dict[str, str] = {
# "key": self._api_key,
# "cx": self._search_engine_id,
# "q": query,
# "num": str(self._num_results),
# }
# response = requests.get(
# GOOGLE_CUSTOM_SEARCH_URL, params=params, timeout=self._timeout_seconds
# )
# # Check for HTTP errors first
# try:
# response.raise_for_status()
# except requests.HTTPError as exc:
# status = response.status_code
# error_detail = "Unknown error"
# try:
# error_data = response.json()
# if "error" in error_data:
# error_info = error_data["error"]
# error_detail = error_info.get("message", str(error_info))
# except Exception:
# error_detail = (
# response.text[:200] if response.text else "No error details"
# )
# raise ValueError(
# f"Google PSE search failed (status {status}): {error_detail}"
# ) from exc
# data = response.json()
# # Google Custom Search API can return errors in the response body even with 200 status
# if "error" in data:
# error_info = data["error"]
# error_message = error_info.get("message", "Unknown error")
# error_code = error_info.get("code", "Unknown")
# raise ValueError(f"Google PSE API error ({error_code}): {error_message}")
# items: list[dict[str, Any]] = data.get("items", [])
# results: list[WebSearchResult] = []
# for item in items:
# link = item.get("link")
# if not link:
# continue
# snippet = item.get("snippet") or ""
# # Attempt to extract metadata if available
# pagemap = item.get("pagemap") or {}
# metatags = pagemap.get("metatags", [])
# published_date: datetime | None = None
# author: str | None = None
# if metatags:
# meta = metatags[0]
# author = meta.get("og:site_name") or meta.get("author")
# published_str = (
# meta.get("article:published_time")
# or meta.get("og:updated_time")
# or meta.get("date")
# )
# if published_str:
# try:
# published_date = datetime.fromisoformat(
# published_str.replace("Z", "+00:00")
# )
# except ValueError:
# logger.debug(
# f"Failed to parse published_date '{published_str}' for link {link}"
# )
# published_date = None
# results.append(
# WebSearchResult(
# title=item.get("title") or "",
# link=link,
# snippet=snippet,
# author=author,
# published_date=published_date,
# )
# )
# return results
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# logger.warning(
# "Google PSE does not support content fetching; returning empty results."
# )
# return [
# WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# for url in urls
# ]

View File

@@ -0,0 +1,94 @@
# from __future__ import annotations
# from collections.abc import Sequence
# import requests
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContent,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
# WebContentProvider,
# )
# from onyx.file_processing.html_utils import ParsedHTML
# from onyx.file_processing.html_utils import web_html_cleanup
# from onyx.utils.logger import setup_logger
# logger = setup_logger()
# DEFAULT_TIMEOUT_SECONDS = 15
# DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
# class OnyxWebCrawler(WebContentProvider):
# """
# Lightweight built-in crawler that fetches HTML directly and extracts readable text.
# Acts as the default content provider when no external crawler (e.g. Firecrawl) is
# configured.
# """
# def __init__(
# self,
# *,
# timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
# user_agent: str = DEFAULT_USER_AGENT,
# ) -> None:
# self._timeout_seconds = timeout_seconds
# self._headers = {
# "User-Agent": user_agent,
# "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
# }
# def contents(self, urls: Sequence[str]) -> list[WebContent]:
# results: list[WebContent] = []
# for url in urls:
# results.append(self._fetch_url(url))
# return results
# def _fetch_url(self, url: str) -> WebContent:
# try:
# response = requests.get(
# url, headers=self._headers, timeout=self._timeout_seconds
# )
# except Exception as exc: # pragma: no cover - network failures vary
# logger.warning(
# "Onyx crawler failed to fetch %s (%s)",
# url,
# exc.__class__.__name__,
# )
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# if response.status_code >= 400:
# logger.warning("Onyx crawler received %s for %s", response.status_code, url)
# return WebContent(
# title="",
# link=url,
# full_content="",
# published_date=None,
# scrape_successful=False,
# )
# try:
# parsed: ParsedHTML = web_html_cleanup(response.text)
# text_content = parsed.cleaned_text or ""
# title = parsed.title or ""
# except Exception as exc:
# logger.warning(
# "Onyx crawler failed to parse %s (%s)", url, exc.__class__.__name__
# )
# text_content = ""
# title = ""
# return WebContent(
# title=title,
# link=url,
# full_content=text_content,
# published_date=None,
# scrape_successful=bool(text_content.strip()),
# )

Some files were not shown because too many files have changed in this diff Show More