Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-02-16 23:35:46 +00:00

Compare commits

107 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 8a38fdf8a5 | |
| | 9155d4aa21 | |
| | b20591611a | |
| | 83e756bf05 | |
| | 19b485cffd | |
| | f5a99053ac | |
| | 91f0377dd5 | |
| | 25522dfbb8 | |
| | b0e124ec89 | |
| | b699a65384 | |
| | cc82d6e506 | |
| | 8a6db7474d | |
| | fd9aea212b | |
| | 4aed383e49 | |
| | d0ce313b1a | |
| | 4d32c9f5e0 | |
| | 158fe31b71 | |
| | 97cddc1dd4 | |
| | c520a4ec17 | |
| | 9c1f8cc98c | |
| | 58ba8cc68a | |
| | a307b0d366 | |
| | e34f58e994 | |
| | 7f6dd2dc93 | |
| | ef3daa58b3 | |
| | 972c33046e | |
| | 802248c4e4 | |
| | f359c44183 | |
| | bab2220091 | |
| | bc35354ced | |
| | 742dd23fdd | |
| | ea5690db81 | |
| | 853ca635d2 | |
| | c4d2fc9492 | |
| | 7aa12c0a36 | |
| | e74cf14401 | |
| | 75c42ffa9d | |
| | d6fbb7affd | |
| | 75cee70bbb | |
| | 1c8b819aa2 | |
| | b7cf33a4cc | |
| | b06459f674 | |
| | 920db6b3c2 | |
| | b7e4b65a74 | |
| | e648e0f725 | |
| | c8a3368fce | |
| | f74b02ad9e | |
| | 65b59c4a73 | |
| | b74bcd0efc | |
| | 8c133b3853 | |
| | 67554cef96 | |
| | 07e03f3677 | |
| | 33fee46d71 | |
| | 72f5e3d38f | |
| | f89380ad87 | |
| | e6f00098f2 | |
| | 9100afa594 | |
| | 93d2febf2a | |
| | 693286411a | |
| | 01a3064ca3 | |
| | 09a80265ee | |
| | 2a77481c1e | |
| | 6838487689 | |
| | 1713c24080 | |
| | 73b3a2525a | |
| | 59738d9243 | |
| | c0ff9c623b | |
| | c03979209a | |
| | a0b7639693 | |
| | e3ede3c186 | |
| | 092dbebdf2 | |
| | 838e2fe924 | |
| | 48e2bfa3eb | |
| | 2a004ad257 | |
| | 416c7fd75e | |
| | a4372b461f | |
| | 7eb13db6d9 | |
| | c0075d5f59 | |
| | 475a3afe56 | |
| | bf5b8e7bae | |
| | 4ff28c897b | |
| | ec9e9be42e | |
| | af5fa8fe54 | |
| | 03a9e9e068 | |
| | ad81c3f9eb | |
| | 62129f4ab9 | |
| | b30d38c747 | |
| | 0596b57501 | |
| | 482b2c4204 | |
| | df155835b1 | |
| | fd0762a1ee | |
| | bd41618dd9 | |
| | 5a7c6312af | |
| | a477508bd7 | |
| | 8ac34a8433 | |
| | 2c51466bc3 | |
| | 62966bd172 | |
| | a8d4482b59 | |
| | dd42a45008 | |
| | a368556282 | |
| | 679d1a5ef6 | |
| | 12e49cd661 | |
| | 1859a0ad79 | |
| | 9199d146be | |
| | 9c1208ffd6 | |
| | c3387e33eb | |
| | c37f633a37 | |
.github/actionlint.yml (vendored, 1 change)

```diff
@@ -17,6 +17,7 @@ self-hosted-runner:
     - runner=16cpu-linux-x64
+    - ubuntu-slim # Currently in public preview
     - volume=40gb
     - volume=50gb

   # Configuration variables in array of strings defined in your repository or
   # organization. `null` means disabling configuration variables check.
```
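This list is what actionlint accepts as valid `runs-on:` labels for custom runners; the new `ubuntu-slim` entry keeps the reworked deployment workflow below lint-clean. To reproduce the check locally, one option is actionlint's official container image (a sketch; the exact invocation depends on your local setup):

```bash
# Lint all workflows under .github/workflows from the repository root.
docker run --rm -v "$PWD:/repo" --workdir /repo rhysd/actionlint:latest -color
```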
.github/actions/custom-build-and-push/action.yml (vendored, 135 changes; file removed)

```diff
@@ -1,135 +0,0 @@
-name: 'Build and Push Docker Image with Retry'
-description: 'Attempts to build and push a Docker image, with a retry on failure'
-inputs:
-  context:
-    description: 'Build context'
-    required: true
-  file:
-    description: 'Dockerfile location'
-    required: true
-  platforms:
-    description: 'Target platforms'
-    required: true
-  pull:
-    description: 'Always attempt to pull a newer version of the image'
-    required: false
-    default: 'true'
-  push:
-    description: 'Push the image to registry'
-    required: false
-    default: 'true'
-  load:
-    description: 'Load the image into Docker daemon'
-    required: false
-    default: 'true'
-  tags:
-    description: 'Image tags'
-    required: true
-  no-cache:
-    description: 'Read from cache'
-    required: false
-    default: 'false'
-  cache-from:
-    description: 'Cache sources'
-    required: false
-  cache-to:
-    description: 'Cache destinations'
-    required: false
-  outputs:
-    description: 'Output destinations'
-    required: false
-  provenance:
-    description: 'Generate provenance attestation'
-    required: false
-    default: 'false'
-  build-args:
-    description: 'Build arguments'
-    required: false
-  retry-wait-time:
-    description: 'Time to wait before attempt 2 in seconds'
-    required: false
-    default: '60'
-  retry-wait-time-2:
-    description: 'Time to wait before attempt 3 in seconds'
-    required: false
-    default: '120'
-
-runs:
-  using: "composite"
-  steps:
-    - name: Build and push Docker image (Attempt 1 of 3)
-      id: buildx1
-      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-      continue-on-error: true
-      with:
-        context: ${{ inputs.context }}
-        file: ${{ inputs.file }}
-        platforms: ${{ inputs.platforms }}
-        pull: ${{ inputs.pull }}
-        push: ${{ inputs.push }}
-        load: ${{ inputs.load }}
-        tags: ${{ inputs.tags }}
-        no-cache: ${{ inputs.no-cache }}
-        cache-from: ${{ inputs.cache-from }}
-        cache-to: ${{ inputs.cache-to }}
-        outputs: ${{ inputs.outputs }}
-        provenance: ${{ inputs.provenance }}
-        build-args: ${{ inputs.build-args }}
-
-    - name: Wait before attempt 2
-      if: steps.buildx1.outcome != 'success'
-      run: |
-        echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
-        sleep ${{ inputs.retry-wait-time }}
-      shell: bash
-
-    - name: Build and push Docker image (Attempt 2 of 3)
-      id: buildx2
-      if: steps.buildx1.outcome != 'success'
-      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-      with:
-        context: ${{ inputs.context }}
-        file: ${{ inputs.file }}
-        platforms: ${{ inputs.platforms }}
-        pull: ${{ inputs.pull }}
-        push: ${{ inputs.push }}
-        load: ${{ inputs.load }}
-        tags: ${{ inputs.tags }}
-        no-cache: ${{ inputs.no-cache }}
-        cache-from: ${{ inputs.cache-from }}
-        cache-to: ${{ inputs.cache-to }}
-        outputs: ${{ inputs.outputs }}
-        provenance: ${{ inputs.provenance }}
-        build-args: ${{ inputs.build-args }}
-
-    - name: Wait before attempt 3
-      if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
-      run: |
-        echo "Second attempt failed. Waiting ${{ inputs.retry-wait-time-2 }} seconds before retry..."
-        sleep ${{ inputs.retry-wait-time-2 }}
-      shell: bash
-
-    - name: Build and push Docker image (Attempt 3 of 3)
-      id: buildx3
-      if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
-      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-      with:
-        context: ${{ inputs.context }}
-        file: ${{ inputs.file }}
-        platforms: ${{ inputs.platforms }}
-        pull: ${{ inputs.pull }}
-        push: ${{ inputs.push }}
-        load: ${{ inputs.load }}
-        tags: ${{ inputs.tags }}
-        no-cache: ${{ inputs.no-cache }}
-        cache-from: ${{ inputs.cache-from }}
-        cache-to: ${{ inputs.cache-to }}
-        outputs: ${{ inputs.outputs }}
-        provenance: ${{ inputs.provenance }}
-        build-args: ${{ inputs.build-args }}
-
-    - name: Report failure
-      if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
-      run: |
-        echo "All attempts failed. Possible transient infrastucture issues? Try again later or inspect logs for details."
-      shell: bash
```
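The removed composite action wrapped `docker/build-push-action` in three attempts with increasing waits (60s, then 120s, per its defaults). A minimal bash sketch of the same retry-with-backoff pattern; the `docker buildx build` invocation and the `IMAGE_TAG` variable are illustrative, not taken from the workflow:

```bash
#!/usr/bin/env bash
# Sketch of the retry pattern the deleted action implemented.
set -u

waits=(60 120)  # seconds before attempts 2 and 3, mirroring the action's defaults

for attempt in 1 2 3; do
  # Hypothetical build command standing in for docker/build-push-action
  if docker buildx build --push -t "$IMAGE_TAG" .; then
    echo "Build succeeded on attempt ${attempt}."
    exit 0
  fi
  if [ "$attempt" -lt 3 ]; then
    wait="${waits[$((attempt - 1))]}"
    echo "Attempt ${attempt} failed. Waiting ${wait}s before retry..."
    sleep "$wait"
  fi
done

echo "All attempts failed. Possible transient infrastructure issues?" >&2
exit 1
```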
.github/actions/prepare-build/action.yml (vendored, 42 changes; file removed)

```diff
@@ -1,42 +0,0 @@
-name: "Prepare Build (OpenAPI generation)"
-description: "Sets up Python with uv, installs deps, generates OpenAPI schema and Python client, uploads artifact"
-inputs:
-  docker-username:
-    required: true
-  docker-password:
-    required: true
-runs:
-  using: "composite"
-  steps:
-    - name: Setup Python and Install Dependencies
-      uses: ./.github/actions/setup-python-and-install-dependencies
-
-    - name: Generate OpenAPI schema
-      shell: bash
-      working-directory: backend
-      env:
-        PYTHONPATH: "."
-      run: |
-        python scripts/onyx_openapi_schema.py --filename generated/openapi.json
-
-    # needed for pulling openapitools/openapi-generator-cli
-    # otherwise, we hit the "Unauthenticated users" limit
-    # https://docs.docker.com/docker-hub/usage/
-    - name: Login to Docker Hub
-      uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
-      with:
-        username: ${{ inputs['docker-username'] }}
-        password: ${{ inputs['docker-password'] }}
-
-    - name: Generate OpenAPI Python client
-      shell: bash
-      run: |
-        docker run --rm \
-          -v "${{ github.workspace }}/backend/generated:/local" \
-          openapitools/openapi-generator-cli generate \
-            -i /local/openapi.json \
-            -g python \
-            -o /local/onyx_openapi_client \
-            --package-name onyx_openapi_client \
-            --skip-validate-spec \
-            --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
```
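The deleted action exported the OpenAPI schema and then ran the `openapitools/openapi-generator-cli` image against it to produce a Python client. The same generation can be reproduced locally; a sketch assuming the schema has already been written to `backend/generated/openapi.json` and that you run from the repository root:

```bash
# Generate a Python client from the exported schema, mirroring the
# deleted action's docker invocation.
docker run --rm \
  -v "$PWD/backend/generated:/local" \
  openapitools/openapi-generator-cli generate \
    -i /local/openapi.json \
    -g python \
    -o /local/onyx_openapi_client \
    --package-name onyx_openapi_client \
    --skip-validate-spec \
    --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
```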
.github/actions/setup-playwright/action.yml (vendored, 4 changes)

```diff
@@ -7,9 +7,9 @@ runs:
       uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
       with:
         path: ~/.cache/ms-playwright
-        key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
+        key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
         restore-keys: |
-          ${{ runner.os }}-playwright-
+          ${{ runner.os }}-${{ runner.arch }}-playwright-

     - name: Install playwright
      shell: bash
```
.github/actions/setup-python-and-install-dependencies/action.yml (vendored)

```diff
@@ -1,5 +1,9 @@
 name: "Setup Python and Install Dependencies"
 description: "Sets up Python with uv and installs deps"
+inputs:
+  requirements:
+    description: "Newline-separated list of requirement files to install (relative to repo root)"
+    required: true
 runs:
   using: "composite"
   steps:
@@ -9,11 +13,26 @@ runs:
     #   with:
     #     enable-cache: true

+    - name: Compute requirements hash
+      id: req-hash
+      shell: bash
+      env:
+        REQUIREMENTS: ${{ inputs.requirements }}
+      run: |
+        # Hash the contents of the specified requirement files
+        hash=""
+        while IFS= read -r req; do
+          if [ -n "$req" ] && [ -f "$req" ]; then
+            hash="$hash$(sha256sum "$req")"
+          fi
+        done <<< "$REQUIREMENTS"
+        echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
+
     - name: Cache uv cache directory
       uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
       with:
         path: ~/.cache/uv
-        key: ${{ runner.os }}-uv-${{ hashFiles('backend/requirements/*.txt', 'backend/pyproject.toml') }}
+        key: ${{ runner.os }}-uv-${{ steps.req-hash.outputs.hash }}
         restore-keys: |
           ${{ runner.os }}-uv-

@@ -24,15 +43,30 @@ runs:

     - name: Create virtual environment
       shell: bash
-      run: |
-        uv venv ${{ runner.temp }}/venv
-        echo "VENV_PATH=${{ runner.temp }}/venv" >> $GITHUB_ENV
-        echo "${{ runner.temp }}/venv/bin" >> $GITHUB_PATH
+      env:
+        VENV_DIR: ${{ runner.temp }}/venv
+      run: | # zizmor: ignore[github-env]
+        uv venv "$VENV_DIR"
+        # Validate path before adding to GITHUB_PATH to prevent code injection
+        if [ -d "$VENV_DIR/bin" ]; then
+          realpath "$VENV_DIR/bin" >> "$GITHUB_PATH"
+        else
+          echo "Error: $VENV_DIR/bin does not exist"
+          exit 1
+        fi

     - name: Install Python dependencies with uv
       shell: bash
+      env:
+        REQUIREMENTS: ${{ inputs.requirements }}
       run: |
-        uv pip install \
-          -r backend/requirements/default.txt \
-          -r backend/requirements/dev.txt \
-          -r backend/requirements/model_server.txt
+        # Build the uv pip install command with each requirement file as array elements
+        cmd=("uv" "pip" "install")
+        while IFS= read -r req; do
+          # Skip empty lines
+          if [ -n "$req" ]; then
+            cmd+=("-r" "$req")
+          fi
+        done <<< "$REQUIREMENTS"
+        echo "Running: ${cmd[*]}"
+        "${cmd[@]}"
```
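The new `req-hash` step replaces the fixed `hashFiles(...)` cache key with a shell-computed digest over whatever file list the caller passes in, so each workflow caches exactly the requirement sets it installs. The computation can be reproduced outside CI; the file list below is illustrative (in the action it arrives via the `requirements` input):

```bash
#!/usr/bin/env bash
# Reproduce the cache-key hash locally.
REQUIREMENTS=$'backend/requirements/default.txt\nbackend/requirements/dev.txt'

hash=""
while IFS= read -r req; do
  # Concatenate "<sha256>  <path>" lines for every file that exists
  if [ -n "$req" ] && [ -f "$req" ]; then
    hash="$hash$(sha256sum "$req")"
  fi
done <<< "$REQUIREMENTS"

# Collapse the concatenated digests into a single key component
echo "cache key suffix: $(echo "$hash" | sha256sum | cut -d' ' -f1)"
```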
.github/actions/slack-notify/action.yml (vendored, 21 changes)

```diff
@@ -21,26 +21,27 @@ runs:
       shell: bash
       env:
         SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
+        FAILED_JOBS: ${{ inputs.failed-jobs }}
+        TITLE: ${{ inputs.title }}
+        REF_NAME: ${{ inputs.ref-name }}
+        REPO: ${{ github.repository }}
+        WORKFLOW: ${{ github.workflow }}
+        RUN_NUMBER: ${{ github.run_number }}
+        RUN_ID: ${{ github.run_id }}
+        SERVER_URL: ${{ github.server_url }}
+        GITHUB_REF_NAME: ${{ github.ref_name }}
       run: |
         if [ -z "$SLACK_WEBHOOK_URL" ]; then
           echo "webhook-url input or SLACK_WEBHOOK_URL env var is not set, skipping notification"
           exit 0
         fi

-        # Get inputs with defaults
-        FAILED_JOBS="${{ inputs.failed-jobs }}"
-        TITLE="${{ inputs.title }}"
-        REF_NAME="${{ inputs.ref-name }}"
-        REPO="${{ github.repository }}"
-        WORKFLOW="${{ github.workflow }}"
-        RUN_NUMBER="${{ github.run_number }}"
-        RUN_ID="${{ github.run_id }}"
-        SERVER_URL="${{ github.server_url }}"
         # Build workflow URL
         WORKFLOW_URL="${SERVER_URL}/${REPO}/actions/runs/${RUN_ID}"

         # Use ref_name from input or fall back to github.ref_name
         if [ -z "$REF_NAME" ]; then
-          REF_NAME="${{ github.ref_name }}"
+          REF_NAME="$GITHUB_REF_NAME"
         fi

         # Escape JSON special characters
```
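The pattern here (and in the docker-tag workflows below) replaces inline `${{ ... }}` interpolation inside `run:` scripts with an `env:` block. Inline expressions are substituted into the script text before the shell ever parses it, so attacker-influenced values such as a branch name or title can become shell syntax; values passed through the environment stay data. A minimal sketch of the distinction, with a hypothetical malicious value:

```bash
# Hypothetical attacker-controlled value, e.g. a crafted branch name
value='"; echo pwned; "'

# Inline interpolation: the runner would paste the raw value into the
# script, producing   TITLE=""; echo pwned; ""   -- the payload executes.

# env-based passing: the shell only ever reads $TITLE as data.
TITLE="$value" bash -c 'printf "title=%s\n" "$TITLE"'
```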
.github/dependabot.yml (vendored, 4 changes)

```diff
@@ -4,6 +4,8 @@ updates:
     directory: "/"
     schedule:
       interval: "weekly"
+    cooldown:
+      default-days: 4
     open-pull-requests-limit: 3
     assignees:
       - "jmelahman"
@@ -13,6 +15,8 @@ updates:
     directory: "/backend"
     schedule:
       interval: "weekly"
+    cooldown:
+      default-days: 4
     open-pull-requests-limit: 3
     assignees:
       - "jmelahman"
```
.github/workflows/check-lazy-imports.yml (vendored, 8 changes)

```diff
@@ -10,13 +10,19 @@ on:
       - main
       - 'release/**'

+permissions:
+  contents: read
+
 jobs:
   check-lazy-imports:
     runs-on: ubuntu-latest
+    timeout-minutes: 45

     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false

       - name: Set up Python
         uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
```
.github/workflows/deployment.yml (vendored, 796 changes)

```diff
@@ -6,6 +6,9 @@ on:
       - "*"
   workflow_dispatch:

+permissions:
+  contents: read
+
 env:
   IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
   EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -15,6 +18,7 @@ jobs:
   determine-builds:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 90
     outputs:
       build-web: ${{ steps.check.outputs.build-web }}
       build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
@@ -30,7 +34,7 @@ jobs:
       - name: Check which components to build and version info
         id: check
         run: |
-          TAG="${{ github.ref_name }}"
+          TAG="${GITHUB_REF_NAME}"
          # Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
          SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
          IS_CLOUD=false
```
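Why the sanitization: Docker tags may not contain `/`, so release-branch refs have to be flattened before they can name an image. For instance (the ref name is hypothetical):

```bash
TAG="release/2.3"                          # hypothetical git ref
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')  # -> release-2.3, a legal Docker tag
echo "$SANITIZED_TAG"
```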
```diff
@@ -79,22 +83,146 @@ jobs:
           echo "sanitized-tag=$SANITIZED_TAG"
         } >> "$GITHUB_OUTPUT"

-  build-web:
+  build-web-amd64:
     needs: determine-builds
     if: needs.determine-builds.outputs.build-web == 'true'
     runs-on:
       - runs-on
       - runner=4cpu-linux-x64
-      - run-id=${{ github.run_id }}-web-build
+      - run-id=${{ github.run_id }}-web-amd64
       - extras=ecr-cache
     timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server
-      DEPLOYMENT: standalone
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false

       - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push AMD64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        with:
+          context: ./web
+          file: ./web/Dockerfile
+          platforms: linux/amd64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+            NODE_OPTIONS=--max-old-space-size=8192
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
+
+  build-web-arm64:
+    needs: determine-builds
+    if: needs.determine-builds.outputs.build-web == 'true'
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - run-id=${{ github.run_id }}-web-arm64
+      - extras=ecr-cache
+    timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push ARM64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        with:
+          context: ./web
+          file: ./web/Dockerfile
+          platforms: linux/arm64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+            NODE_OPTIONS=--max-old-space-size=8192
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
+
+  merge-web:
+    needs:
+      - determine-builds
+      - build-web-amd64
+      - build-web-arm64
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-merge-web
+      - extras=ecr-cache
+    timeout-minutes: 90
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Docker meta
+        id: meta
```
```diff
@@ -109,50 +237,38 @@ jobs:
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}

-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+      - name: Create and push manifest
+        env:
+          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
+          META_TAGS: ${{ steps.meta.outputs.tags }}
+        run: |
+          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
+          docker buildx imagetools create \
+            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
+            $IMAGES

-      - name: Login to Docker Hub
-        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_USERNAME }}
-          password: ${{ secrets.DOCKER_TOKEN }}
-
-      - name: Build and push
-        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-        with:
-          context: ./web
-          file: ./web/Dockerfile
-          platforms: linux/amd64,linux/arm64
-          push: true
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          build-args: |
-            ONYX_VERSION=${{ github.ref_name }}
-            NODE_OPTIONS=--max-old-space-size=8192
-          cache-from: |
-            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache
-          cache-to: |
-            type=inline
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache,mode=max

-  build-web-cloud:
+  build-web-cloud-amd64:
     needs: determine-builds
     if: needs.determine-builds.outputs.build-web-cloud == 'true'
     runs-on:
       - runs-on
       - runner=4cpu-linux-x64
-      - run-id=${{ github.run_id }}-web-cloud-build
+      - run-id=${{ github.run_id }}-web-cloud-amd64
       - extras=ecr-cache
     timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
-      DEPLOYMENT: cloud
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
```
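This is the new shape of the pipeline: each architecture is built natively on its own runner and pushed by digest only (`push-by-digest=true`), and a small merge job then stitches the two digests into one multi-arch manifest list. A standalone sketch of what the merge jobs run, with hypothetical digests and tag:

```bash
IMAGE_REPO=onyxdotapp/onyx-web-server
AMD64_DIGEST=sha256:aaaa...   # hypothetical digest output by the amd64 build job
ARM64_DIGEST=sha256:bbbb...   # hypothetical digest output by the arm64 build job

# One manifest list, tagged v1.2.3, pointing at both per-arch images
docker buildx imagetools create \
  -t "${IMAGE_REPO}:v1.2.3" \
  "${IMAGE_REPO}@${AMD64_DIGEST}" \
  "${IMAGE_REPO}@${ARM64_DIGEST}"
```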
```diff
@@ -161,8 +277,6 @@ jobs:
           images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
           flavor: |
             latest=false
-          tags: |
-            type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -173,14 +287,13 @@ jobs:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

-      - name: Build and push
+      - name: Build and push AMD64
+        id: build
         uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
         with:
           context: ./web
           file: ./web/Dockerfile
-          platforms: linux/amd64,linux/arm64
-          push: true
-          tags: ${{ steps.meta.outputs.tags }}
+          platforms: linux/amd64
           labels: ${{ steps.meta.outputs.labels }}
           build-args: |
             ONYX_VERSION=${{ github.ref_name }}
@@ -195,27 +308,264 @@ jobs:
             NODE_OPTIONS=--max-old-space-size=8192
           cache-from: |
             type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
           cache-to: |
             type=inline
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache,mode=max
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

-  build-backend:
+  build-web-cloud-arm64:
+    needs: determine-builds
+    if: needs.determine-builds.outputs.build-web-cloud == 'true'
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - run-id=${{ github.run_id }}-web-cloud-arm64
+      - extras=ecr-cache
+    timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push ARM64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        with:
+          context: ./web
+          file: ./web/Dockerfile
+          platforms: linux/arm64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+            NEXT_PUBLIC_CLOUD_ENABLED=true
+            NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
+            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
+            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
+            NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
+            NEXT_PUBLIC_GTM_ENABLED=true
+            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
+            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
+            NODE_OPTIONS=--max-old-space-size=8192
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
+
+  merge-web-cloud:
+    needs:
+      - determine-builds
+      - build-web-cloud-amd64
+      - build-web-cloud-arm64
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-merge-web-cloud
+      - extras=ecr-cache
+    timeout-minutes: 90
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+          tags: |
+            type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
+
+      - name: Create and push manifest
+        env:
+          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
+          META_TAGS: ${{ steps.meta.outputs.tags }}
+        run: |
+          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
+          docker buildx imagetools create \
+            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
+            $IMAGES
+
+  build-backend-amd64:
     needs: determine-builds
     if: needs.determine-builds.outputs.build-backend == 'true'
     runs-on:
       - runs-on
       - runner=2cpu-linux-x64
-      - run-id=${{ github.run_id }}-backend-build
+      - run-id=${{ github.run_id }}-backend-amd64
       - extras=ecr-cache
     timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
-      DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false

       - name: Docker meta
         id: meta
         uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
         with:
           images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
           flavor: |
             latest=false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

       - name: Login to Docker Hub
         uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

+      - name: Build and push AMD64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        with:
+          context: ./backend
+          file: ./backend/Dockerfile
+          platforms: linux/amd64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
+
+  build-backend-arm64:
+    needs: determine-builds
+    if: needs.determine-builds.outputs.build-backend == 'true'
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-arm64
+      - run-id=${{ github.run_id }}-backend-arm64
+      - extras=ecr-cache
+    timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push ARM64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        with:
+          context: ./backend
+          file: ./backend/Dockerfile
+          platforms: linux/arm64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
+
+  merge-backend:
+    needs:
+      - determine-builds
+      - build-backend-amd64
+      - build-backend-arm64
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-merge-backend
+      - extras=ecr-cache
+    timeout-minutes: 90
+    env:
+      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Docker meta
+        id: meta
```
```diff
@@ -230,6 +580,162 @@ jobs:
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

+      - name: Create and push manifest
+        env:
+          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
+          META_TAGS: ${{ steps.meta.outputs.tags }}
+        run: |
+          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
+          docker buildx imagetools create \
+            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
+            $IMAGES
+
+  build-model-server-amd64:
+    needs: determine-builds
+    if: needs.determine-builds.outputs.build-model-server == 'true'
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-model-server-amd64
+      - volume=40gb
+      - extras=ecr-cache
+    timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+        with:
+          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push AMD64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        env:
+          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
+        with:
+          context: ./backend
+          file: ./backend/Dockerfile.model_server
+          platforms: linux/amd64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
+          provenance: false
+          sbom: false
+
+  build-model-server-arm64:
+    needs: determine-builds
+    if: needs.determine-builds.outputs.build-model-server == 'true'
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-arm64
+      - run-id=${{ github.run_id }}-model-server-arm64
+      - volume=40gb
+      - extras=ecr-cache
+    timeout-minutes: 90
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Docker meta
+        id: meta
+        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
+        with:
+          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          flavor: |
+            latest=false
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
+        with:
+          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Build and push ARM64
+        id: build
+        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
+        env:
+          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
+        with:
+          context: ./backend
+          file: ./backend/Dockerfile.model_server
+          platforms: linux/arm64
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            ONYX_VERSION=${{ github.ref_name }}
+          cache-from: |
+            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
+          cache-to: |
+            type=inline
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
+          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+          no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
+          provenance: false
+          sbom: false
+
+  merge-model-server:
+    needs:
+      - determine-builds
+      - build-model-server-amd64
+      - build-model-server-arm64
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-merge-model-server
+      - extras=ecr-cache
+    timeout-minutes: 90
+    env:
+      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
```
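A note on the `$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {})` idiom used by every manifest step: `docker/metadata-action` emits its tags as a newline-separated list, and this pipeline expands each entry into a repeated `-t` flag for `imagetools create`. Standalone, with illustrative tags:

```bash
# Newline-separated tag list, as metadata-action would emit it (illustrative)
META_TAGS=$'onyxdotapp/onyx-backend:v1.2.3\nonyxdotapp/onyx-backend:latest'

# Expand each tag into "-t <tag>" so imagetools receives one flag per tag
printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}
# -> -t onyxdotapp/onyx-backend:v1.2.3
#    -t onyxdotapp/onyx-backend:latest
```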
```diff
@@ -239,44 +745,6 @@ jobs:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

-      - name: Build and push
-        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-        with:
-          context: ./backend
-          file: ./backend/Dockerfile
-          platforms: linux/amd64,linux/arm64
-          push: true
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          build-args: |
-            ONYX_VERSION=${{ github.ref_name }}
-          cache-from: |
-            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache
-          cache-to: |
-            type=inline
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache,mode=max
-
-  build-model-server:
-    needs: determine-builds
-    if: needs.determine-builds.outputs.build-model-server == 'true'
-    runs-on:
-      - runs-on
-      - runner=2cpu-linux-x64
-      - run-id=${{ github.run_id }}-model-server-build
-      - volume=40gb
-      - extras=ecr-cache
-    env:
-      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
-      DOCKER_BUILDKIT: 1
-      BUILDKIT_PROGRESS: plain
-      DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
-    steps:
-      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
-
-      - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
-
-      - name: Docker meta
-        id: meta
-        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
@@ -290,45 +758,29 @@ jobs:
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
             type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
-        with:
-          driver-opts: |
-            image=moby/buildkit:latest
-            network=host
-
-      - name: Login to Docker Hub
-        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_USERNAME }}
-          password: ${{ secrets.DOCKER_TOKEN }}
-
-      - name: Build and push
-        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
-        with:
-          context: ./backend
-          file: ./backend/Dockerfile.model_server
-          platforms: linux/amd64,linux/arm64
-          push: true
-          tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
-          build-args: |
-            ONYX_VERSION=${{ github.ref_name }}
-          cache-from: |
-            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache
-          cache-to: |
-            type=inline
-            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache,mode=max
+      - name: Create and push manifest
+        env:
+          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
+          AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
+          META_TAGS: ${{ steps.meta.outputs.tags }}
+        run: |
+          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
+          docker buildx imagetools create \
+            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
+            $IMAGES

   trivy-scan-web:
-    needs: [determine-builds, build-web]
-    if: needs.build-web.result == 'success'
+    needs:
+      - determine-builds
+      - merge-web
+    if: needs.merge-web.result == 'success'
     runs-on:
       - runs-on
-      - runner=2cpu-linux-x64
+      - runner=2cpu-linux-arm64
       - run-id=${{ github.run_id }}-trivy-scan-web
       - extras=ecr-cache
+    timeout-minutes: 90
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server
     steps:
```
```diff
@@ -359,13 +811,16 @@ jobs:
             ${SCAN_IMAGE}

   trivy-scan-web-cloud:
-    needs: [determine-builds, build-web-cloud]
-    if: needs.build-web-cloud.result == 'success'
+    needs:
+      - determine-builds
+      - merge-web-cloud
+    if: needs.merge-web-cloud.result == 'success'
     runs-on:
       - runs-on
-      - runner=2cpu-linux-x64
+      - runner=2cpu-linux-arm64
       - run-id=${{ github.run_id }}-trivy-scan-web-cloud
       - extras=ecr-cache
+    timeout-minutes: 90
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
     steps:
@@ -396,20 +851,25 @@ jobs:
             ${SCAN_IMAGE}

   trivy-scan-backend:
-    needs: [determine-builds, build-backend]
-    if: needs.build-backend.result == 'success'
+    needs:
+      - determine-builds
+      - merge-backend
+    if: needs.merge-backend.result == 'success'
     runs-on:
       - runs-on
-      - runner=2cpu-linux-x64
+      - runner=2cpu-linux-arm64
       - run-id=${{ github.run_id }}-trivy-scan-backend
       - extras=ecr-cache
+    timeout-minutes: 90
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false

       - name: Run Trivy vulnerability scanner
         uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
```
@@ -438,13 +898,16 @@ jobs:
|
||||
${SCAN_IMAGE}
|
||||
|
||||
trivy-scan-model-server:
|
||||
needs: [determine-builds, build-model-server]
|
||||
if: needs.build-model-server.result == 'success'
|
||||
needs:
|
||||
- determine-builds
|
||||
- merge-model-server
|
||||
if: needs.merge-model-server.result == 'success'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
@@ -475,33 +938,86 @@ jobs:
|
||||
${SCAN_IMAGE}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [build-web, build-web-cloud, build-backend, build-model-server]
|
||||
if: always() && (needs.build-web.result == 'failure' || needs.build-web-cloud.result == 'failure' || needs.build-backend.result == 'failure' || needs.build-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
needs:
|
||||
- build-web-amd64
|
||||
- build-web-arm64
|
||||
- merge-web
|
||||
- build-web-cloud-amd64
|
||||
- build-web-cloud-arm64
|
||||
- merge-web-cloud
|
||||
- build-backend-amd64
|
||||
- build-backend-arm64
|
||||
- merge-backend
|
||||
- build-model-server-amd64
|
||||
      - build-model-server-arm64
      - merge-model-server
    if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 90
    steps:
      - name: Checkout
-       uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+       uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Determine failed jobs
        id: failed-jobs
        shell: bash
        run: |
          FAILED_JOBS=""
-         if [ "${{ needs.build-web.result }}" == "failure" ]; then
-           FAILED_JOBS="${FAILED_JOBS}• build-web\\n"
-         fi
-         if [ "${{ needs.build-web-cloud.result }}" == "failure" ]; then
-           FAILED_JOBS="${FAILED_JOBS}• build-web-cloud\\n"
-         fi
-         if [ "${{ needs.build-backend.result }}" == "failure" ]; then
-           FAILED_JOBS="${FAILED_JOBS}• build-backend\\n"
-         fi
-         if [ "${{ needs.build-model-server.result }}" == "failure" ]; then
-           FAILED_JOBS="${FAILED_JOBS}• build-model-server\\n"
-         fi
+         if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
+         fi
+         if [ "${NEEDS_BUILD_WEB_ARM64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-web-arm64\\n"
+         fi
+         if [ "${NEEDS_MERGE_WEB_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• merge-web\\n"
+         fi
+         if [ "${NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-amd64\\n"
+         fi
+         if [ "${NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-arm64\\n"
+         fi
+         if [ "${NEEDS_MERGE_WEB_CLOUD_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• merge-web-cloud\\n"
+         fi
+         if [ "${NEEDS_BUILD_BACKEND_AMD64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-backend-amd64\\n"
+         fi
+         if [ "${NEEDS_BUILD_BACKEND_ARM64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-backend-arm64\\n"
+         fi
+         if [ "${NEEDS_MERGE_BACKEND_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• merge-backend\\n"
+         fi
+         if [ "${NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-model-server-amd64\\n"
+         fi
+         if [ "${NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• build-model-server-arm64\\n"
+         fi
+         if [ "${NEEDS_MERGE_MODEL_SERVER_RESULT}" == "failure" ]; then
+           FAILED_JOBS="${FAILED_JOBS}• merge-model-server\\n"
+         fi
          # Remove trailing \n and set output
          FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
          echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
+       env:
+         NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
+         NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
+         NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
+         NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT: ${{ needs.build-web-cloud-amd64.result }}
+         NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT: ${{ needs.build-web-cloud-arm64.result }}
+         NEEDS_MERGE_WEB_CLOUD_RESULT: ${{ needs.merge-web-cloud.result }}
+         NEEDS_BUILD_BACKEND_AMD64_RESULT: ${{ needs.build-backend-amd64.result }}
+         NEEDS_BUILD_BACKEND_ARM64_RESULT: ${{ needs.build-backend-arm64.result }}
+         NEEDS_MERGE_BACKEND_RESULT: ${{ needs.merge-backend.result }}
+         NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT: ${{ needs.build-model-server-amd64.result }}
+         NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT: ${{ needs.build-model-server-arm64.result }}
+         NEEDS_MERGE_MODEL_SERVER_RESULT: ${{ needs.merge-model-server.result }}

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
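A side note on the change above: `${{ ... }}` expressions are expanded by the workflow runner before the shell ever starts, so interpolating `needs.*.result` directly into `run:` splices values into the script text, while routing them through `env:` keeps the script body static and the values shell-quoted. A minimal sketch of the same indirection, reduced to a single hypothetical upstream job:

    - name: Determine failed jobs
      id: failed-jobs
      shell: bash
      env:
        # Expansion happens here, at the YAML layer, not inside the script body.
        NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
      run: |
        FAILED_JOBS=""
        # No ${{ }} templates below: the value arrives as a plain env var,
        # so unexpected content cannot alter the shell syntax.
        if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
          FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\n"
        fi
        echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"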
16 .github/workflows/docker-tag-beta.yml vendored

@@ -10,11 +10,15 @@ on:
        description: "The version (ie v1.0.0-beta.0) to tag as beta"
        required: true

+permissions:
+  contents: read
+
 jobs:
   tag:
     # See https://runs-on.com/runners/linux/
     # use a lower powered instance since this just does i/o to docker hub
     runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
+    timeout-minutes: 45
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -29,13 +33,19 @@ jobs:
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}

      - name: Pull, Tag and Push API Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}

      - name: Pull, Tag and Push Model Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}
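For context on why these tag workflows run on a small instance: `docker buildx imagetools create` copies an existing manifest (including a multi-arch manifest list) to a new tag entirely on the registry side, so no image layers are pulled or rebuilt. A minimal sketch, assuming the versioned tag is already published on Docker Hub:

    # Point the moving "beta" tag at an already-published version.
    # Only manifest data moves; no layers are downloaded or pushed.
    VERSION=v1.0.0-beta.0  # example value supplied via workflow input
    docker buildx imagetools create \
      -t onyxdotapp/onyx-web-server:beta \
      "onyxdotapp/onyx-web-server:${VERSION}"

The same pattern applies verbatim to the docker-tag-latest.yml workflow below, with `latest` in place of `beta`.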
16 .github/workflows/docker-tag-latest.yml vendored

@@ -10,11 +10,15 @@ on:
        description: "The version (ie v0.0.1) to tag as latest"
        required: true

+permissions:
+  contents: read
+
 jobs:
   tag:
     # See https://runs-on.com/runners/linux/
     # use a lower powered instance since this just does i/o to docker hub
     runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
+    timeout-minutes: 45
     steps:
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -29,13 +33,19 @@ jobs:
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}

      - name: Pull, Tag and Push API Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}

      - name: Pull, Tag and Push Model Server Image
+       env:
+         VERSION: ${{ github.event.inputs.version }}
        run: |
-         docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
+         docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}
4 .github/workflows/helm-chart-releases.yml vendored

@@ -12,11 +12,13 @@ jobs:
     permissions:
       contents: write
     runs-on: ubuntu-latest
+    timeout-minutes: 45
     steps:
       - name: Checkout
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
+          persist-credentials: false

       - name: Install Helm CLI
         uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4

@@ -11,8 +11,9 @@ permissions:
 jobs:
   stale:
     runs-on: ubuntu-latest
+    timeout-minutes: 45
     steps:
-      - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # ratchet:actions/stale@v9
+      - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
         with:
           stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
           stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
15 .github/workflows/nightly-scan-licenses.yml vendored

@@ -15,16 +15,22 @@ on:
 permissions:
   actions: read
   contents: read
   security-events: write

 jobs:
   scan-licenses:
     # See https://runs-on.com/runners/linux/
     runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
+    timeout-minutes: 45
+    permissions:
+      actions: read
+      contents: read
+      security-events: write
+
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Python
         uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
@@ -54,7 +60,9 @@ jobs:
       - name: Print report
         if: always()
-        run: echo "${{ steps.license_check_report.outputs.report }}"
+        env:
+          REPORT: ${{ steps.license_check_report.outputs.report }}
+        run: echo "$REPORT"

       - name: Install npm dependencies
         working-directory: ./web
@@ -82,6 +90,7 @@ jobs:
   scan-trivy:
     # See https://runs-on.com/runners/linux/
     runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
+    timeout-minutes: 45

     steps:
       - name: Set up Docker Buildx

@@ -8,13 +8,18 @@ on:
   pull_request:
     branches: [main]

+permissions:
+  contents: read
+
 env:
   # AWS
-  S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
-  S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
+  # AWS credentials for S3-specific test
+  S3_AWS_ACCESS_KEY_ID_FOR_TEST: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
+  S3_AWS_SECRET_ACCESS_KEY_FOR_TEST: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
+
+  # MinIO
+  S3_ENDPOINT_URL: "http://localhost:9004"
+  S3_AWS_ACCESS_KEY_ID: "minioadmin"
+  S3_AWS_SECRET_ACCESS_KEY: "minioadmin"

   # Confluence
   CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
@@ -28,15 +33,22 @@ env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

+  # Code Interpreter
+  # TODO: debug why this is failing and enable
+  CODE_INTERPRETER_BASE_URL: http://localhost:8000
+
 jobs:
   discover-test-dirs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 45
     outputs:
       test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Discover test directories
         id: set-matrix
@@ -53,6 +65,7 @@ jobs:
       - runner=2cpu-linux-arm64
       - ${{ format('run-id={0}-external-dependency-unit-tests-job-{1}', github.run_id, strategy['job-index']) }}
       - extras=s3-cache
+    timeout-minutes: 45
     strategy:
       fail-fast: false
       matrix:
@@ -66,10 +79,17 @@ jobs:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Setup Python and Install Dependencies
         uses: ./.github/actions/setup-python-and-install-dependencies
         with:
           requirements: |
             backend/requirements/default.txt
             backend/requirements/dev.txt
             backend/requirements/ee.txt

       - name: Setup Playwright
         uses: ./.github/actions/setup-playwright
@@ -83,10 +103,24 @@ jobs:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

+      - name: Create .env file for Docker Compose
+        run: |
+          cat <<EOF > deployment/docker_compose/.env
+          CODE_INTERPRETER_BETA_ENABLED=true
+          EOF
+
       - name: Set up Standard Dependencies
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index
+          docker compose \
+            -f docker-compose.yml \
+            -f docker-compose.dev.yml \
+            up -d \
+            minio \
+            relational_db \
+            cache \
+            index \
+            code-interpreter

       - name: Run migrations
         run: |
@@ -97,10 +131,39 @@ jobs:
       - name: Run Tests for ${{ matrix.test-dir }}
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
+        env:
+          TEST_DIR: ${{ matrix.test-dir }}
         run: |
           py.test \
             --durations=8 \
             -o junit_family=xunit2 \
             -xv \
             --ff \
-            backend/tests/external_dependency_unit/${{ matrix.test-dir }}
+            backend/tests/external_dependency_unit/${TEST_DIR}

+      - name: Collect Docker logs on failure
+        if: failure()
+        run: |
+          mkdir -p docker-logs
+          cd deployment/docker_compose
+
+          # Get list of running containers
+          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
+
+          # Collect logs from each container
+          for container in $containers; do
+            container_name=$(docker inspect --format='{{.Name}}' $container | sed 's/^\///')
+            echo "Collecting logs from $container_name..."
+            docker logs $container > ../../docker-logs/${container_name}.log 2>&1
+          done
+
+          cd ../..
+          echo "Docker logs collected in docker-logs directory"
+
+      - name: Upload Docker logs
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name: docker-logs-${{ matrix.test-dir }}
+          path: docker-logs/
+          retention-days: 7
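Two details of the test step above generalize. `shell: script -q -e -c "bash ..."` runs the step under a pseudo-TTY, so pytest keeps its line-buffered, colored output in CI while `script -e` still propagates the inner exit code; and the matrix value is passed through `env:` instead of being templated into the script body. A condensed sketch of that step, using the names from the diff:

    - name: Run Tests
      # `script` allocates a pseudo-TTY; -e returns the child's exit status,
      # so a failing pytest run still fails the step.
      shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
      env:
        TEST_DIR: ${{ matrix.test-dir }}
      run: |
        py.test --durations=8 -o junit_family=xunit2 -xv --ff \
          backend/tests/external_dependency_unit/${TEST_DIR}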
13 .github/workflows/pr-helm-chart-testing.yml vendored

@@ -9,17 +9,22 @@ on:
     branches: [ main ]
   workflow_dispatch: # Allows manual triggering

+permissions:
+  contents: read
+
 jobs:
   helm-chart-check:
     # See https://runs-on.com/runners/linux/
     runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
+    timeout-minutes: 45

     # fetch-depth 0 is required for helm/chart-testing-action
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Set up Helm
         uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
@@ -32,9 +37,11 @@ jobs:
       # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
       - name: Run chart-testing (list-changed)
         id: list-changed
+        env:
+          DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
         run: |
-          echo "default_branch: ${{ github.event.repository.default_branch }}"
-          changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
+          echo "default_branch: ${DEFAULT_BRANCH}"
+          changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
           echo "list-changed output: $changed"
           if [[ -n "$changed" ]]; then
             echo "changed=true" >> "$GITHUB_OUTPUT"
160 .github/workflows/pr-integration-tests.yml vendored

@@ -10,6 +10,9 @@ on:
       - main
       - "release/**"

+permissions:
+  contents: read
+
 env:
   # Test Environment Variables
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -32,11 +35,14 @@ jobs:
   discover-test-dirs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 45
     outputs:
       test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Discover test directories
         id: set-matrix
@@ -61,10 +67,13 @@ jobs:

   build-backend-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -85,17 +94,23 @@ jobs:
           file: ./backend/Dockerfile
           push: true
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
+            type=registry,ref=onyxdotapp/onyx-backend:latest
+          cache-to: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-model-server-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -116,16 +131,21 @@ jobs:
           file: ./backend/Dockerfile.model_server
           push: true
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
+            type=registry,ref=onyxdotapp/onyx-model-server:latest
+          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max

   build-integration-image:
     runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -141,9 +161,16 @@ jobs:

       - name: Build and push integration test image with Docker Bake
         env:
-          REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
+          INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
           TAG: integration-test-${{ github.run_id }}
-        run: cd backend && docker buildx bake --push integration
+        run: |
+          cd backend && docker buildx bake --push \
+            --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
+            --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
+            --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
+            --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
+            --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
+            integration

   integration-tests:
     needs:
@@ -158,6 +185,7 @@ jobs:
       - runner=4cpu-linux-arm64
       - ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
       - extras=ecr-cache
+    timeout-minutes: 45

     strategy:
       fail-fast: false
@@ -167,7 +195,9 @@ jobs:
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       # needed for pulling Vespa, Redis, Postgres, and Minio images
       # otherwise, we hit the "Unauthenticated users" limit
@@ -181,6 +211,9 @@ jobs:
       # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
       # NOTE: don't need web server for integration tests
       - name: Start Docker containers
+        env:
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
         run: |
           cd deployment/docker_compose
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -189,10 +222,11 @@ jobs:
           POSTGRES_USE_NULL_POOL=true \
           REQUIRE_EMAIL_VERIFICATION=false \
           DISABLE_TELEMETRY=true \
-          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
-          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
+          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
           INTEGRATION_TESTS_MODE=true \
           CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
+          MCP_SERVER_ENABLED=true \
           docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
             relational_db \
             index \
@@ -201,43 +235,56 @@ jobs:
             api_server \
             inference_model_server \
             indexing_model_server \
+            mcp_server \
             background \
             -d
         id: start_docker

-      - name: Wait for service to be ready
+      - name: Wait for services to be ready
         run: |
           echo "Starting wait-for-service script..."

           docker logs -f onyx-api_server-1 &

-          start_time=$(date +%s)
-          timeout=300 # 5 minutes in seconds
-
-          while true; do
-            current_time=$(date +%s)
-            elapsed_time=$((current_time - start_time))
-
-            if [ $elapsed_time -ge $timeout ]; then
-              echo "Timeout reached. Service did not become ready in 5 minutes."
-              exit 1
-            fi
-
-            # Use curl with error handling to ignore specific exit code 56
-            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
-
-            if [ "$response" = "200" ]; then
-              echo "Service is ready!"
-              break
-            elif [ "$response" = "curl_error" ]; then
-              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
-            else
-              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
-            fi
-
-            sleep 5
-          done
-          echo "Finished waiting for service."
+          wait_for_service() {
+            local url=$1
+            local label=$2
+            local timeout=${3:-300} # default 5 minutes
+            local start_time
+            start_time=$(date +%s)
+
+            while true; do
+              local current_time
+              current_time=$(date +%s)
+              local elapsed_time=$((current_time - start_time))
+
+              if [ $elapsed_time -ge $timeout ]; then
+                echo "Timeout reached. ${label} did not become ready in $timeout seconds."
+                exit 1
+              fi
+
+              local response
+              response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
+
+              if [ "$response" = "200" ]; then
+                echo "${label} is ready!"
+                break
+              elif [ "$response" = "curl_error" ]; then
+                echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
+              else
+                echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
+              fi
+
+              sleep 5
+            done
+          }
+
+          wait_for_service "http://localhost:8080/health" "API server"
+          test_dir="${{ matrix.test-dir.path }}"
+          if [ "$test_dir" = "tests/mcp" ]; then
+            wait_for_service "http://localhost:8090/health" "MCP server"
+          else
+            echo "Skipping MCP server wait for non-MCP suite: $test_dir"
+          fi
+          echo "Finished waiting for services."

       - name: Start Mock Services
         run: |
@@ -266,7 +313,10 @@ jobs:
           -e VESPA_HOST=index \
           -e REDIS_HOST=cache \
           -e API_SERVER_HOST=api_server \
+          -e MCP_SERVER_HOST=mcp_server \
+          -e MCP_SERVER_PORT=8090 \
           -e OPENAI_API_KEY=${OPENAI_API_KEY} \
           -e EXA_API_KEY=${EXA_API_KEY} \
           -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
           -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
           -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
@@ -317,11 +367,14 @@ jobs:
       build-integration-image,
     ]
     runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
+    timeout-minutes: 45

     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Login to Docker Hub
         uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -330,6 +383,9 @@ jobs:
           password: ${{ secrets.DOCKER_TOKEN }}

       - name: Start Docker containers for multi-tenant tests
+        env:
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
         run: |
           cd deployment/docker_compose
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -337,9 +393,10 @@ jobs:
           AUTH_TYPE=cloud \
           REQUIRE_EMAIL_VERIFICATION=false \
           DISABLE_TELEMETRY=true \
-          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
-          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
+          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
           DEV_MODE=true \
+          MCP_SERVER_ENABLED=true \
           docker compose -f docker-compose.multitenant-dev.yml up \
             relational_db \
             index \
@@ -348,6 +405,7 @@ jobs:
             api_server \
             inference_model_server \
             indexing_model_server \
+            mcp_server \
             background \
             -d
         id: start_docker_multi_tenant
@@ -379,6 +437,9 @@ jobs:
           echo "Finished waiting for service."

       - name: Run Multi-Tenant Integration Tests
+        env:
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
         run: |
           echo "Running multi-tenant integration tests..."
           docker run --rm --network onyx_default \
@@ -393,7 +454,10 @@ jobs:
           -e VESPA_HOST=index \
           -e REDIS_HOST=cache \
           -e API_SERVER_HOST=api_server \
+          -e MCP_SERVER_HOST=mcp_server \
+          -e MCP_SERVER_PORT=8090 \
           -e OPENAI_API_KEY=${OPENAI_API_KEY} \
           -e EXA_API_KEY=${EXA_API_KEY} \
           -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
           -e TEST_WEB_HOSTNAME=test-runner \
           -e AUTH_TYPE=cloud \
@@ -402,7 +466,7 @@ jobs:
           -e REQUIRE_EMAIL_VERIFICATION=false \
           -e DISABLE_TELEMETRY=true \
           -e DEV_MODE=true \
-          ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
+          ${ECR_CACHE}:integration-test-${RUN_ID} \
           /app/tests/integration/multitenant_tests

       - name: Dump API server logs (multi-tenant)
@@ -433,16 +497,10 @@ jobs:
   required:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 45
     needs: [integration-tests, multitenant-tests]
     if: ${{ always() }}
     steps:
-      - name: Check job status
-        if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
-        run: exit 1
+      - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
+        with:
+          script: |
+            const needs = ${{ toJSON(needs) }};
+            const failed = Object.values(needs).some(n => n.result !== 'success');
+            if (failed) {
+              core.setFailed('One or more upstream jobs failed or were cancelled.');
+            } else {
+              core.notice('All required jobs succeeded.');
+            }
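The readiness refactor above folds a copy-pasted polling loop into one parameterized helper. A consolidated sketch of that function as introduced, assuming only curl is available; note that a non-200 status and a curl transport error (e.g. exit code 56 while a port is still opening) are both treated as "not ready yet":

    wait_for_service() {
      local url=$1 label=$2 timeout=${3:-300}  # timeout in seconds
      local start_time
      start_time=$(date +%s)
      while true; do
        if [ $(( $(date +%s) - start_time )) -ge "$timeout" ]; then
          echo "Timeout reached. ${label} did not become ready in ${timeout} seconds."
          return 1
        fi
        local response
        # "|| echo curl_error" keeps transient transport failures from killing the loop.
        response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
        if [ "$response" = "200" ]; then
          echo "${label} is ready!"
          return 0
        fi
        echo "${label} not ready yet (${response}). Retrying in 5 seconds..."
        sleep 5
      done
    }

    wait_for_service "http://localhost:8080/health" "API server"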
8 .github/workflows/pr-jest-tests.yml vendored

@@ -5,13 +5,19 @@ concurrency:

 on: push

+permissions:
+  contents: read
+
 jobs:
   jest-tests:
     name: Jest Tests
     runs-on: ubuntu-latest
+    timeout-minutes: 45
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Setup node
         uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
4 .github/workflows/pr-labeler.yml vendored

@@ -1,7 +1,7 @@
 name: PR Labeler

 on:
-  pull_request_target:
+  pull_request:
     branches:
       - main
     types:
@@ -12,11 +12,11 @@ on:

 permissions:
   contents: read
   pull-requests: write

 jobs:
   validate_pr_title:
     runs-on: ubuntu-latest
+    timeout-minutes: 45
     steps:
       - name: Check PR title for Conventional Commits
         env:
4 .github/workflows/pr-linear-check.yml vendored

@@ -7,9 +7,13 @@ on:
   pull_request:
     types: [opened, edited, reopened, synchronize]

+permissions:
+  contents: read
+
 jobs:
   linear-check:
     runs-on: ubuntu-latest
+    timeout-minutes: 45
     steps:
       - name: Check PR body for Linear link or override
         env:
138 .github/workflows/pr-mit-integration-tests.yml vendored

@@ -7,10 +7,14 @@ on:
   merge_group:
     types: [checks_requested]

+permissions:
+  contents: read
+
 env:
   # Test Environment Variables
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
   EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
   CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -28,11 +32,14 @@ jobs:
   discover-test-dirs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 45
     outputs:
       test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
     steps:
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Discover test directories
         id: set-matrix
@@ -56,10 +63,13 @@ jobs:

   build-backend-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -80,16 +90,21 @@ jobs:
           file: ./backend/Dockerfile
           push: true
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
+            type=registry,ref=onyxdotapp/onyx-backend:latest
+          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-model-server-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -110,15 +125,20 @@ jobs:
           file: ./backend/Dockerfile.model_server
           push: true
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
+            type=registry,ref=onyxdotapp/onyx-model-server:latest
+          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max

   build-integration-image:
     runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -134,9 +154,16 @@ jobs:

       - name: Build and push integration test image with Docker Bake
         env:
-          REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
+          INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
           TAG: integration-test-${{ github.run_id }}
-        run: cd backend && docker buildx bake --push integration
+        run: |
+          cd backend && docker buildx bake --push \
+            --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
+            --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
+            --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
+            --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
+            --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
+            integration

   integration-tests-mit:
     needs:
@@ -151,6 +178,7 @@ jobs:
       - runner=4cpu-linux-arm64
       - ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
       - extras=ecr-cache
+    timeout-minutes: 45

     strategy:
       fail-fast: false
@@ -160,7 +188,9 @@ jobs:
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       # needed for pulling Vespa, Redis, Postgres, and Minio images
       # otherwise, we hit the "Unauthenticated users" limit
@@ -174,6 +204,9 @@ jobs:
       # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
       # NOTE: don't need web server for integration tests
       - name: Start Docker containers
+        env:
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
         run: |
           cd deployment/docker_compose
           AUTH_TYPE=basic \
@@ -181,9 +214,10 @@ jobs:
           POSTGRES_USE_NULL_POOL=true \
           REQUIRE_EMAIL_VERIFICATION=false \
           DISABLE_TELEMETRY=true \
-          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
-          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
+          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
           INTEGRATION_TESTS_MODE=true \
+          MCP_SERVER_ENABLED=true \
           docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
             relational_db \
             index \
@@ -192,43 +226,56 @@ jobs:
             api_server \
             inference_model_server \
             indexing_model_server \
+            mcp_server \
             background \
             -d
         id: start_docker

-      - name: Wait for service to be ready
+      - name: Wait for services to be ready
         run: |
           echo "Starting wait-for-service script..."

           docker logs -f onyx-api_server-1 &

-          start_time=$(date +%s)
-          timeout=300 # 5 minutes in seconds
-
-          while true; do
-            current_time=$(date +%s)
-            elapsed_time=$((current_time - start_time))
-
-            if [ $elapsed_time -ge $timeout ]; then
-              echo "Timeout reached. Service did not become ready in 5 minutes."
-              exit 1
-            fi
-
-            # Use curl with error handling to ignore specific exit code 56
-            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
-
-            if [ "$response" = "200" ]; then
-              echo "Service is ready!"
-              break
-            elif [ "$response" = "curl_error" ]; then
-              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
-            else
-              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
-            fi
-
-            sleep 5
-          done
-          echo "Finished waiting for service."
+          wait_for_service() {
+            local url=$1
+            local label=$2
+            local timeout=${3:-300} # default 5 minutes
+            local start_time
+            start_time=$(date +%s)
+
+            while true; do
+              local current_time
+              current_time=$(date +%s)
+              local elapsed_time=$((current_time - start_time))
+
+              if [ $elapsed_time -ge $timeout ]; then
+                echo "Timeout reached. ${label} did not become ready in $timeout seconds."
+                exit 1
+              fi
+
+              local response
+              response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
+
+              if [ "$response" = "200" ]; then
+                echo "${label} is ready!"
+                break
+              elif [ "$response" = "curl_error" ]; then
+                echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
+              else
+                echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
+              fi
+
+              sleep 5
+            done
+          }
+
+          wait_for_service "http://localhost:8080/health" "API server"
+          test_dir="${{ matrix.test-dir.path }}"
+          if [ "$test_dir" = "tests/mcp" ]; then
+            wait_for_service "http://localhost:8090/health" "MCP server"
+          else
+            echo "Skipping MCP server wait for non-MCP suite: $test_dir"
+          fi
+          echo "Finished waiting for services."

       - name: Start Mock Services
         run: |
@@ -258,7 +305,10 @@ jobs:
           -e VESPA_HOST=index \
           -e REDIS_HOST=cache \
           -e API_SERVER_HOST=api_server \
+          -e MCP_SERVER_HOST=mcp_server \
+          -e MCP_SERVER_PORT=8090 \
           -e OPENAI_API_KEY=${OPENAI_API_KEY} \
           -e EXA_API_KEY=${EXA_API_KEY} \
           -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
           -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
           -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
@@ -304,16 +354,10 @@ jobs:
   required:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
+    timeout-minutes: 45
     needs: [integration-tests-mit]
     if: ${{ always() }}
     steps:
-      - name: Check job status
-        if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
-        run: exit 1
+      - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
+        with:
+          script: |
+            const needs = ${{ toJSON(needs) }};
+            const failed = Object.values(needs).some(n => n.result !== 'success');
+            if (failed) {
+              core.setFailed('One or more upstream jobs failed or were cancelled.');
+            } else {
+              core.notice('All required jobs succeeded.');
+            }
147 .github/workflows/pr-playwright-tests.yml vendored

@@ -5,6 +5,9 @@ concurrency:

 on: push

+permissions:
+  contents: read
+
 env:
   # Test Environment Variables
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -24,6 +27,13 @@ env:
   MCP_OAUTH_USERNAME: ${{ vars.MCP_OAUTH_USERNAME }}
   MCP_OAUTH_PASSWORD: ${{ secrets.MCP_OAUTH_PASSWORD }}

+  # for MCP API Key tests
+  MCP_API_KEY: test-api-key-12345
+  MCP_API_KEY_TEST_PORT: 8005
+  MCP_API_KEY_TEST_URL: http://host.docker.internal:8005/mcp
+  MCP_API_KEY_SERVER_HOST: 0.0.0.0
+  MCP_API_KEY_SERVER_PUBLIC_HOST: host.docker.internal
+
   MOCK_LLM_RESPONSE: true
   MCP_TEST_SERVER_PORT: 8004
   MCP_TEST_SERVER_URL: http://host.docker.internal:8004/mcp
@@ -37,11 +47,14 @@ jobs:
   build-web-image:
     runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -62,17 +75,22 @@ jobs:
           platforms: linux/arm64
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
           push: true
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
+            type=registry,ref=onyxdotapp/onyx-web-server:latest
+          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-backend-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -93,17 +111,23 @@ jobs:
           platforms: linux/arm64
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
           push: true
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
+            type=registry,ref=onyxdotapp/onyx-backend:latest
+          cache-to: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-model-server-image:
     runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
+    timeout-minutes: 45
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -124,14 +148,22 @@ jobs:
           platforms: linux/arm64
           tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
           push: true
-          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache
-          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache,mode=max
+          cache-from: |
+            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
+            type=registry,ref=onyxdotapp/onyx-model-server:latest
+          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   playwright-tests:
     needs: [build-web-image, build-backend-image, build-model-server-image]
     name: Playwright Tests (${{ matrix.project }})
-    runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}", "extras=ecr-cache"]
+    runs-on:
+      - runs-on
+      - runner=8cpu-linux-arm64
+      - "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}"
+      - "extras=ecr-cache"
+      - volume=50gb
     timeout-minutes: 45
     strategy:
       fail-fast: false
       matrix:
@@ -140,9 +172,10 @@ jobs:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Setup node
         uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
@@ -168,18 +201,26 @@ jobs:
         run: npx playwright install --with-deps

       - name: Create .env file for Docker Compose
+        env:
+          OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
+          EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }}
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
         run: |
           cat <<EOF > deployment/docker_compose/.env
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
           AUTH_TYPE=basic
-          GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
-          EXA_API_KEY=${{ env.EXA_API_KEY }}
+          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
+          EXA_API_KEY=${EXA_API_KEY_VALUE}
          REQUIRE_EMAIL_VERIFICATION=false
          DISABLE_TELEMETRY=true
-          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
-          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
-          ONYX_WEB_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
+          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID}
+          ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
           EOF
+          if [ "${{ matrix.project }}" = "no-auth" ]; then
+            echo "PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS=true" >> deployment/docker_compose/.env
+          fi

       # needed for pulling Vespa, Redis, Postgres, and Minio images
       # otherwise, we hit the "Unauthenticated users" limit
@@ -193,7 +234,7 @@ jobs:
       - name: Start Docker containers
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml up -d
+          docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml -f docker-compose.mcp-api-key-test.yml up -d
         id: start_docker

       - name: Wait for service to be ready
@@ -253,12 +294,65 @@ jobs:
             sleep 3
           done

+      - name: Wait for MCP API Key mock server
+        run: |
+          echo "Waiting for MCP API Key mock server on port ${MCP_API_KEY_TEST_PORT:-8005}..."
+          start_time=$(date +%s)
+          timeout=120
+
+          while true; do
+            current_time=$(date +%s)
+            elapsed_time=$((current_time - start_time))
+
+            if [ $elapsed_time -ge $timeout ]; then
+              echo "Timeout reached. MCP API Key mock server did not become ready in ${timeout}s."
+              exit 1
+            fi
+
+            if curl -sf "http://localhost:${MCP_API_KEY_TEST_PORT:-8005}/healthz" > /dev/null; then
+              echo "MCP API Key mock server is ready!"
+              break
+            fi
+
+            sleep 3
+          done
+
+      - name: Wait for web server to be ready
+        run: |
+          echo "Waiting for web server on port 3000..."
+          start_time=$(date +%s)
+          timeout=120
+
+          while true; do
+            current_time=$(date +%s)
+            elapsed_time=$((current_time - start_time))
+
+            if [ $elapsed_time -ge $timeout ]; then
+              echo "Timeout reached. Web server did not become ready in ${timeout}s."
+              exit 1
+            fi
+
+            if curl -sf "http://localhost:3000/api/health" > /dev/null 2>&1 || \
+               curl -sf "http://localhost:3000/" > /dev/null 2>&1; then
+              echo "Web server is ready!"
+              break
+            fi
+
+            echo "Web server not ready yet. Retrying in 3 seconds..."
+            sleep 3
+          done

       - name: Run Playwright tests
         working-directory: ./web
+        env:
+          PROJECT: ${{ matrix.project }}
         run: |
           # Create test-results directory to ensure it exists for artifact upload
           mkdir -p test-results
-          npx playwright test --project ${{ matrix.project }}
+          if [ "${PROJECT}" = "no-auth" ]; then
+            export PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS=true
+          fi
+          npx playwright test --project ${PROJECT}

       - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
         if: always()
@@ -271,10 +365,12 @@ jobs:
       # save before stopping the containers so the logs can be captured
       - name: Save Docker logs
         if: success() || failure()
+        env:
+          WORKSPACE: ${{ github.workspace }}
         run: |
           cd deployment/docker_compose
           docker compose logs > docker-compose.log
-          mv docker-compose.log ${{ github.workspace }}/docker-compose.log
+          mv docker-compose.log ${WORKSPACE}/docker-compose.log

       - name: Upload logs
         if: success() || failure()
@@ -283,6 +379,17 @@ jobs:
           name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
           path: ${{ github.workspace }}/docker-compose.log

+  playwright-required:
+    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
+    runs-on: ubuntu-slim
+    timeout-minutes: 45
+    needs: [playwright-tests]
+    if: ${{ always() }}
+    steps:
+      - name: Check job status
+        if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
+        run: exit 1

 # NOTE: Chromatic UI diff testing is currently disabled.
 # We are using Playwright for local and CI testing without visual regression checks.
@@ -301,7 +408,7 @@ jobs:
 #   ]
 #   steps:
 #     - name: Checkout code
-#       uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
+#       uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
 #       with:
 #         fetch-depth: 0
60
.github/workflows/pr-python-checks.yml
vendored
60
.github/workflows/pr-python-checks.yml
vendored
@@ -10,17 +10,58 @@ on:
      - main
      - 'release/**'

permissions:
  contents: read

jobs:
  validate-requirements:
    runs-on: ubuntu-slim
    timeout-minutes: 45
    steps:
      - name: Checkout code
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup uv
        uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
        # TODO: Enable caching once there is a uv.lock file checked in.
        # with:
        #   enable-cache: true

      - name: Validate requirements lock files
        run: ./backend/scripts/compile_requirements.py --check

  mypy-check:
    # See https://runs-on.com/runners/linux/
    # Note: Mypy seems quite optimized for x64 compared to arm64.
    # Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
    timeout-minutes: 45

    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
            backend/requirements/ee.txt

      - name: Generate OpenAPI schema
        shell: bash
        working-directory: backend
        env:
          PYTHONPATH: "."
        run: |
          python scripts/onyx_openapi_schema.py --filename generated/openapi.json

      # needed for pulling openapitools/openapi-generator-cli
      # otherwise, we hit the "Unauthenticated users" limit
@@ -31,11 +72,18 @@ jobs:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Prepare build
        uses: ./.github/actions/prepare-build
        with:
          docker-username: ${{ secrets.DOCKER_USERNAME }}
          docker-password: ${{ secrets.DOCKER_TOKEN }}
      - name: Generate OpenAPI Python client
        shell: bash
        run: |
          docker run --rm \
            -v "${{ github.workspace }}/backend/generated:/local" \
            openapitools/openapi-generator-cli generate \
            -i /local/openapi.json \
            -g python \
            -o /local/onyx_openapi_client \
            --package-name onyx_openapi_client \
            --skip-validate-spec \
            --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"

      - name: Cache mypy cache
        if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
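A minimal sketch of consuming the client generated above. It assumes the standard layout that openapi-generator's `python` generator emits (top-level `Configuration` and `ApiClient` classes, which the `--package-name onyx_openapi_client` flag places under that package); the host URL is illustrative, not confirmed by this diff.

# Sketch only: names follow openapi-generator's python output conventions.
from onyx_openapi_client import ApiClient, Configuration

config = Configuration(host="http://localhost:8080")  # assumed local API server
client = ApiClient(config)
# Generated endpoint classes wrap each tagged route group, so mypy can
# type-check calls against the schema produced in the previous step.
print(client.configuration.host)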
.github/workflows/pr-python-connector-tests.yml (16 changes, vendored)
@@ -11,6 +11,9 @@ on:
    # This cron expression runs the job daily at 16:00 UTC (9am PT)
    - cron: "0 16 * * *"

permissions:
  contents: read

env:
  # AWS
  AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
@@ -123,6 +126,7 @@ jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-connectors-check", "extras=s3-cache"]
    timeout-minutes: 45

    env:
      PYTHONPATH: ./backend
@@ -131,10 +135,16 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt

      - name: Setup Playwright
        uses: ./.github/actions/setup-playwright
@@ -214,8 +224,10 @@ jobs:
        if: failure() && github.event_name == 'schedule'
        env:
          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
          REPO: ${{ github.repository }}
          RUN_ID: ${{ github.run_id }}
        run: |
          curl -X POST \
            -H 'Content-type: application/json' \
            --data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
            --data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
            $SLACK_WEBHOOK
.github/workflows/pr-python-model-tests.yml (12 changes, vendored)
@@ -11,6 +11,9 @@ on:
        required: false
        default: 'main'

permissions:
  contents: read

env:
  # Bedrock
  AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -29,13 +32,16 @@ jobs:
  model-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
    timeout-minutes: 45

    env:
      PYTHONPATH: ./backend

    steps:
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -122,10 +128,12 @@ jobs:
        if: failure() && github.event_name == 'schedule'
        env:
          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
          REPO: ${{ github.repository }}
          RUN_ID: ${{ github.run_id }}
        run: |
          curl -X POST \
            -H 'Content-type: application/json' \
            --data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
            --data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
            $SLACK_WEBHOOK

      - name: Dump all-container logs (optional)
.github/workflows/pr-python-tests.yml (14 changes, vendored)
@@ -10,10 +10,14 @@ on:
      - main
      - 'release/**'

permissions:
  contents: read

jobs:
  backend-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-backend-check"]
    timeout-minutes: 45

    env:
@@ -27,10 +31,18 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
            backend/requirements/ee.txt

      - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
.github/workflows/pr-quality-checks.yml (15 changes, vendored)
@@ -7,20 +7,33 @@ on:
  merge_group:
  pull_request: null

permissions:
  contents: read

jobs:
  quality-checks:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-quality-checks"]
    timeout-minutes: 45
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false
      - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
        with:
          python-version: "3.11"
      - name: Setup Terraform
        uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
        env:
          # uv-run is mypy's id and mypy is covered by the Python Checks which caches dependencies better.
          SKIP: uv-run
        with:
          extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
      - name: Check Actions
        uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
        with:
          check_permissions: false
          check_versions: false
.github/workflows/sync_foss.yml (4 changes, vendored)
@@ -9,13 +9,15 @@ on:
jobs:
  sync-foss:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    permissions:
      contents: read
    steps:
      - name: Checkout main Onyx repo
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Install git-filter-repo
        run: |
.github/workflows/tag-nightly.yml (25 changes, vendored)
@@ -3,30 +3,30 @@ name: Nightly Tag Push
on:
  schedule:
    - cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
  workflow_dispatch:

permissions:
  contents: write # Allows pushing tags to the repository

jobs:
  create-and-push-tag:
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-create-and-push-tag"]
    runs-on: ubuntu-slim
    timeout-minutes: 45

    steps:
      # actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
      # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
      # implement here which needs an actual user's deploy key

      # Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
      # and not rkuo's personal account. It is fine to leave this key as is!
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
          ssh-key: "${{ secrets.DEPLOY_KEY }}"
          persist-credentials: true

      - name: Set up Git user
        run: |
          git config user.name "Richard Kuo [bot]"
          git config user.email "rkuo[bot]@onyx.app"
          git config user.name "Onyx Bot [bot]"
          git config user.email "onyx-bot[bot]@onyx.app"

      - name: Check for existing nightly tag
        id: check_tag
@@ -54,3 +54,12 @@ jobs:
        run: |
          TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
          git push origin $TAG_NAME

      - name: Send Slack notification
        if: failure()
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
          title: "🚨 Nightly Tag Push Failed"
          ref-name: ${{ github.ref_name }}
          failed-jobs: "create-and-push-tag"
.github/workflows/zizmor.yml (36 changes, vendored, new file)
@@ -0,0 +1,36 @@
name: Run Zizmor

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["**"]

permissions: {}

jobs:
  zizmor:
    name: zizmor
    runs-on: ubuntu-slim
    timeout-minutes: 45
    permissions:
      security-events: write # needed for SARIF uploads
    steps:
      - name: Checkout repository
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
        with:
          persist-credentials: false

      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3

      - name: Run zizmor
        run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload SARIF file
        uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
        with:
          sarif_file: results.sarif
          category: zizmor
.gitignore (5 changes, vendored)
@@ -1,6 +1,7 @@
# editors
.vscode
.zed
.cursor

# macos
.DS_store
@@ -28,6 +29,8 @@ settings.json

# others
/deployment/data/nginx/app.conf
/deployment/data/nginx/mcp.conf.inc
/deployment/data/nginx/mcp_upstream.conf.inc
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
*.egg-info
@@ -46,5 +49,7 @@ CLAUDE.md
# Local .terraform.lock.hcl file
.terraform.lock.hcl

node_modules

# MCP configs
.playwright-mcp
@@ -1,8 +0,0 @@
{
  "mcpServers": {
    "onyx-mcp": {
      "type": "http",
      "url": "http://localhost:8000/mcp"
    }
  }
}
@@ -1,4 +1,20 @@
default_install_hook_types:
  - pre-commit
  - post-checkout
  - post-merge
  - post-rewrite
repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    # This revision is from https://github.com/astral-sh/uv-pre-commit/pull/53
    rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
    hooks:
      - id: uv-sync
      - id: uv-run
        name: mypy
        args: ["mypy"]
        pass_filenames: true
        files: ^backend/.*\.py$

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
@@ -71,31 +87,3 @@ repos:
        entry: python3 backend/scripts/check_lazy_imports.py
        language: system
        files: ^backend/(?!\.venv/).*\.py$

# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type
# stubs. This does not seem worth it at the moment, so for now we will stick to
# having mypy run via Github Actions / manually by contributors
# - repo: https://github.com/pre-commit/mirrors-mypy
#   rev: v1.1.1
#   hooks:
#     - id: mypy
#       exclude: ^tests/
#       # below are needed for type stubs since pre-commit runs in it's own
#       # isolated environment. Unfortunately, this needs to be kept in sync
#       # with requirements/dev.txt + requirements/default.txt
#       additional_dependencies: [
#         alembic==1.10.4,
#         types-beautifulsoup4==4.12.0.3,
#         types-html5lib==1.1.11.13,
#         types-oauthlib==3.2.0.9,
#         types-psycopg2==2.9.21.10,
#         types-python-dateutil==2.8.19.13,
#         types-regex==2023.3.23.1,
#         types-requests==2.28.11.17,
#         types-retry==0.9.9.3,
#         types-urllib3==1.26.25.11
#       ]
#       # TODO: add back once errors are addressed
#       # args: [--strict]
.vscode/launch.template.jsonc (29 changes, vendored)
@@ -20,6 +20,7 @@
        "Web Server",
        "Model Server",
        "API Server",
        "MCP Server",
        "Slack Bot",
        "Celery primary",
        "Celery light",
@@ -152,6 +153,34 @@
      },
      "consoleTitle": "Slack Bot Console"
    },
    {
      "name": "MCP Server",
      "consoleName": "MCP Server",
      "type": "debugpy",
      "request": "launch",
      "module": "uvicorn",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "MCP_SERVER_ENABLED": "true",
        "MCP_SERVER_PORT": "8090",
        "MCP_SERVER_CORS_ORIGINS": "http://localhost:*",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1"
      },
      "args": [
        "onyx.mcp_server.api:mcp_app",
        "--reload",
        "--port",
        "8090",
        "--timeout-graceful-shutdown",
        "0"
      ],
      "presentation": {
        "group": "2"
      },
      "consoleTitle": "MCP Server Console"
    },
    {
      "name": "Celery primary",
      "type": "debugpy",
@@ -12,6 +12,13 @@ ENV DANSWER_RUNNING_IN_DOCKER="true" \
    DO_NOT_TRACK="true" \
    PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"

# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
    useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
    mkdir -p /var/log/onyx && \
    chmod 755 /var/log/onyx && \
    chown onyx:onyx /var/log/onyx

COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/

# Install system dependencies
@@ -51,6 +58,7 @@ RUN uv pip install --system --no-cache-dir --upgrade \
    pip uninstall -y py && \
    playwright install chromium && \
    playwright install-deps chromium && \
    chown -R onyx:onyx /app && \
    ln -s /usr/local/bin/supervisord /usr/bin/supervisord && \
    # Cleanup for CVEs and size reduction
    # https://github.com/tornadoweb/tornado/issues/3107
@@ -94,13 +102,6 @@ tiktoken.get_encoding('cl100k_base')"
# Set up application files
WORKDIR /app

# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
    useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
    mkdir -p /var/log/onyx && \
    chmod 755 /var/log/onyx && \
    chown onyx:onyx /var/log/onyx

# Enterprise Version Files
COPY --chown=onyx:onyx ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
@@ -1,4 +1,42 @@
FROM python:3.11.7-slim-bookworm
# Base stage with dependencies
FROM python:3.11.7-slim-bookworm AS base

ENV DANSWER_RUNNING_IN_DOCKER="true" \
    HF_HOME=/app/.cache/huggingface

COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/

RUN mkdir -p /app/.cache/huggingface

COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
    -r /tmp/requirements.txt && \
    rm -rf ~/.cache/uv /tmp/*.txt

# Stage for downloading tokenizers
FROM base AS tokenizers
RUN python -c "from transformers import AutoTokenizer; \
    AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
    AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1');"

# Stage for downloading Onyx models
FROM base AS onyx-models
RUN python -c "from huggingface_hub import snapshot_download; \
    snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
    snapshot_download(repo_id='onyx-dot-app/information-content-model');"

# Stage for downloading embedding and reranking models
FROM base AS embedding-models
RUN python -c "from huggingface_hub import snapshot_download; \
    snapshot_download('nomic-ai/nomic-embed-text-v1'); \
    snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"

# Initialize SentenceTransformer to cache the custom architecture
RUN python -c "from sentence_transformers import SentenceTransformer; \
    SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"

# Final stage - combine all downloads
FROM base AS final

LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is for the Onyx model server which runs all of the \
@@ -6,44 +44,19 @@ AI models for Onyx. This container and all the code is MIT Licensed and free for
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
visit https://github.com/onyx-dot-app/onyx."

ENV DANSWER_RUNNING_IN_DOCKER="true" \
    HF_HOME=/app/.cache/huggingface

COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/

# Create non-root user for security best practices
RUN mkdir -p /app && \
    groupadd -g 1001 onyx && \
    useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
    chown -R onyx:onyx /app && \
RUN groupadd -g 1001 onyx && \
    useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
    mkdir -p /var/log/onyx && \
    chmod 755 /var/log/onyx && \
    chown onyx:onyx /var/log/onyx

COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
    -r /tmp/requirements.txt && \
    rm -rf ~/.cache/uv /tmp/*.txt

# Pre-downloading models for setups with limited egress
# Download tokenizers, distilbert for the Onyx model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
    AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
    AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
    from huggingface_hub import snapshot_download; \
    snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
    snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
    snapshot_download('nomic-ai/nomic-embed-text-v1'); \
    snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
    from sentence_transformers import SentenceTransformer; \
    SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);" && \
    # In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
    # running Onyx, move the current contents of the cache folder to a temporary location to ensure
    # it's preserved in order to combine with the user's cache contents
    mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
    chown -R onyx:onyx /app
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
COPY --chown=onyx:onyx --from=tokenizers /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=onyx-models /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface

WORKDIR /app
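The comments above describe staging downloads into /app/.cache/temp_huggingface so that a user-mounted cache at /app/.cache/huggingface is not clobbered. The merge itself happens outside this diff, so the sketch below is an assumption about what a startup step might do: copy anything staged that the mounted cache does not already contain, never overwriting the user's own downloads.

# Sketch only: paths and merge behavior are assumptions, not code from this diff.
import shutil
from pathlib import Path

TEMP_CACHE = Path("/app/.cache/temp_huggingface")
REAL_CACHE = Path("/app/.cache/huggingface")

def merge_staged_cache() -> None:
    if not TEMP_CACHE.is_dir():
        return
    REAL_CACHE.mkdir(parents=True, exist_ok=True)
    for item in TEMP_CACHE.rglob("*"):
        target = REAL_CACHE / item.relative_to(TEMP_CACHE)
        if item.is_dir():
            target.mkdir(parents=True, exist_ok=True)
        elif not target.exists():  # keep the user's existing downloads intact
            shutil.copy2(item, target)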
@@ -0,0 +1,89 @@
"""add internet search and content provider tables

Revision ID: 1f2a3b4c5d6e
Revises: 9drpiiw74ljy
Create Date: 2025-11-10 19:45:00.000000

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "1f2a3b4c5d6e"
down_revision = "9drpiiw74ljy"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "internet_search_provider",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("name", sa.String(), nullable=False, unique=True),
        sa.Column("provider_type", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column(
            "is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
    )
    op.create_index(
        "ix_internet_search_provider_is_active",
        "internet_search_provider",
        ["is_active"],
    )

    op.create_table(
        "internet_content_provider",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("name", sa.String(), nullable=False, unique=True),
        sa.Column("provider_type", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column(
            "is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
    )
    op.create_index(
        "ix_internet_content_provider_is_active",
        "internet_content_provider",
        ["is_active"],
    )


def downgrade() -> None:
    op.drop_index(
        "ix_internet_content_provider_is_active", table_name="internet_content_provider"
    )
    op.drop_table("internet_content_provider")
    op.drop_index(
        "ix_internet_search_provider_is_active", table_name="internet_search_provider"
    )
    op.drop_table("internet_search_provider")
@@ -0,0 +1,89 @@
"""seed_exa_provider_from_env

Revision ID: 3c9a65f1207f
Revises: 1f2a3b4c5d6e
Create Date: 2025-11-20 19:18:00.000000

"""

from __future__ import annotations

import os

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from dotenv import load_dotenv, find_dotenv

from onyx.utils.encryption import encrypt_string_to_bytes

revision = "3c9a65f1207f"
down_revision = "1f2a3b4c5d6e"
branch_labels = None
depends_on = None


EXA_PROVIDER_NAME = "Exa"


def _get_internet_search_table(metadata: sa.MetaData) -> sa.Table:
    return sa.Table(
        "internet_search_provider",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String),
        sa.Column("provider_type", sa.String),
        sa.Column("api_key", sa.LargeBinary),
        sa.Column("config", postgresql.JSONB),
        sa.Column("is_active", sa.Boolean),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
    )


def upgrade() -> None:
    load_dotenv(find_dotenv())

    exa_api_key = os.environ.get("EXA_API_KEY")
    if not exa_api_key:
        return

    bind = op.get_bind()
    metadata = sa.MetaData()
    table = _get_internet_search_table(metadata)

    existing = bind.execute(
        sa.select(table.c.id).where(table.c.name == EXA_PROVIDER_NAME)
    ).first()
    if existing:
        return

    encrypted_key = encrypt_string_to_bytes(exa_api_key)

    has_active_provider = bind.execute(
        sa.select(table.c.id).where(table.c.is_active.is_(True))
    ).first()

    bind.execute(
        table.insert().values(
            name=EXA_PROVIDER_NAME,
            provider_type="exa",
            api_key=encrypted_key,
            config=None,
            is_active=not bool(has_active_provider),
        )
    )


def downgrade() -> None:
    return
backend/alembic/versions/4f8a2b3c1d9e_add_open_url_tool.py (104 changes, new file)
@@ -0,0 +1,104 @@
"""add_open_url_tool

Revision ID: 4f8a2b3c1d9e
Revises: a852cbe15577
Create Date: 2025-11-24 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4f8a2b3c1d9e"
down_revision = "a852cbe15577"
branch_labels = None
depends_on = None


OPEN_URL_TOOL = {
    "name": "OpenURLTool",
    "display_name": "Open URL",
    "description": (
        "The Open URL Action allows the agent to fetch and read contents of web pages."
    ),
    "in_code_tool_id": "OpenURLTool",
    "enabled": True,
}


def upgrade() -> None:
    conn = op.get_bind()

    # Check if tool already exists
    existing = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
        {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
    ).fetchone()

    if existing:
        tool_id = existing[0]
        # Update existing tool
        conn.execute(
            sa.text(
                """
                UPDATE tool
                SET name = :name,
                    display_name = :display_name,
                    description = :description
                WHERE in_code_tool_id = :in_code_tool_id
                """
            ),
            OPEN_URL_TOOL,
        )
    else:
        # Insert new tool
        conn.execute(
            sa.text(
                """
                INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
                VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
                """
            ),
            OPEN_URL_TOOL,
        )
        # Get the newly inserted tool's id
        result = conn.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
            {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
        ).fetchone()
        tool_id = result[0]  # type: ignore

    # Associate the tool with all existing personas
    # Get all persona IDs
    persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()

    for (persona_id,) in persona_ids:
        # Check if association already exists
        exists = conn.execute(
            sa.text(
                """
                SELECT 1 FROM persona__tool
                WHERE persona_id = :persona_id AND tool_id = :tool_id
                """
            ),
            {"persona_id": persona_id, "tool_id": tool_id},
        ).fetchone()

        if not exists:
            conn.execute(
                sa.text(
                    """
                    INSERT INTO persona__tool (persona_id, tool_id)
                    VALUES (:persona_id, :tool_id)
                    """
                ),
                {"persona_id": persona_id, "tool_id": tool_id},
            )


def downgrade() -> None:
    # We don't remove the tool on downgrade since it's fine to have it around.
    # If we upgrade again, it will be a no-op.
    pass
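The persona backfill above issues one SELECT and possibly one INSERT per persona. As a sketch of an alternative design, the same work could be a single set-based statement, assuming persona__tool enforces uniqueness on (persona_id, tool_id) — an assumption this diff does not confirm. The snippet reuses conn, sa, and tool_id from the upgrade() scope above.

# Hypothetical drop-in replacement for the per-persona loop in upgrade().
conn.execute(
    sa.text(
        """
        INSERT INTO persona__tool (persona_id, tool_id)
        SELECT p.id, :tool_id FROM persona p
        ON CONFLICT DO NOTHING
        """
    ),
    {"tool_id": tool_id},
)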
@@ -0,0 +1,44 @@
"""add_created_at_in_project_userfile

Revision ID: 6436661d5b65
Revises: c7e9f4a3b2d1
Create Date: 2025-11-24 11:50:24.536052

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "6436661d5b65"
down_revision = "c7e9f4a3b2d1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add created_at column to project__user_file table
    op.add_column(
        "project__user_file",
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
    )
    # Add composite index on (project_id, created_at DESC)
    op.create_index(
        "ix_project__user_file_project_id_created_at",
        "project__user_file",
        ["project_id", sa.text("created_at DESC")],
    )


def downgrade() -> None:
    # Remove composite index on (project_id, created_at)
    op.drop_index(
        "ix_project__user_file_project_id_created_at", table_name="project__user_file"
    )
    # Remove created_at column from project__user_file table
    op.drop_column("project__user_file", "created_at")
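A sketch of the query shape the composite (project_id, created_at DESC) index above is built to serve: "newest files in a project first." The lightweight table clause here is illustrative; the real application code presumably goes through its ORM models.

# Sketch only: table reflection and function name are assumptions.
import sqlalchemy as sa

def latest_files(conn: sa.Connection, project_id: int, limit: int = 20):
    t = sa.table(
        "project__user_file",
        sa.column("project_id", sa.Integer),
        sa.column("created_at", sa.DateTime(timezone=True)),
    )
    stmt = (
        sa.select(t)
        .where(t.c.project_id == project_id)
        .order_by(t.c.created_at.desc())  # matches the DESC in the index definition
        .limit(limit)
    )
    return conn.execute(stmt).fetchall()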
backend/alembic/versions/a852cbe15577_new_chat_history.py (572 changes, new file)
@@ -0,0 +1,572 @@
"""New Chat History

Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop research agent tables (if they exist)
    op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
    op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")

    # Drop agent sub query and sub question tables (if they exist)
    op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")

    # Update ChatMessage table
    # Rename parent_message to parent_message_id and make it a foreign key (if not already done)
    conn = op.get_bind()
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'parent_message'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "chat_message", "parent_message", new_column_name="parent_message_id"
        )
        op.create_foreign_key(
            "fk_chat_message_parent_message_id",
            "chat_message",
            "chat_message",
            ["parent_message_id"],
            ["id"],
        )

    # Rename latest_child_message to latest_child_message_id and make it a foreign key (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'latest_child_message'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "chat_message",
            "latest_child_message",
            new_column_name="latest_child_message_id",
        )
        op.create_foreign_key(
            "fk_chat_message_latest_child_message_id",
            "chat_message",
            "chat_message",
            ["latest_child_message_id"],
            ["id"],
        )

    # Add reasoning_tokens column (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'chat_message' AND column_name = 'reasoning_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    # Drop columns no longer needed (if they exist)
    for col in [
        "rephrased_query",
        "alternate_assistant_id",
        "overridden_model",
        "is_agentic",
        "refined_answer_improvement",
        "research_type",
        "research_plan",
        "research_answer_purpose",
    ]:
        result = conn.execute(
            sa.text(
                f"""
                SELECT column_name FROM information_schema.columns
                WHERE table_name = 'chat_message' AND column_name = '{col}'
                """
            )
        )
        if result.fetchone():
            op.drop_column("chat_message", col)

    # Update ToolCall table
    # Add chat_session_id column (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'chat_session_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        )
        op.create_foreign_key(
            "fk_tool_call_chat_session_id",
            "tool_call",
            "chat_session",
            ["chat_session_id"],
            ["id"],
        )

    # Rename message_id to parent_chat_message_id and make nullable (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'message_id'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "tool_call",
            "message_id",
            new_column_name="parent_chat_message_id",
            nullable=True,
        )

    # Add parent_tool_call_id (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'parent_tool_call_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
        )
        op.create_foreign_key(
            "fk_tool_call_parent_tool_call_id",
            "tool_call",
            "tool_call",
            ["parent_tool_call_id"],
            ["id"],
        )
        op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")

    # Add turn_number, tool_id (if not exists)
    for col_name in ["turn_number", "tool_id"]:
        result = conn.execute(
            sa.text(
                f"""
                SELECT column_name FROM information_schema.columns
                WHERE table_name = 'tool_call' AND column_name = '{col_name}'
                """
            )
        )
        if not result.fetchone():
            op.add_column(
                "tool_call",
                sa.Column(col_name, sa.Integer(), nullable=False, server_default="0"),
            )

    # Add tool_call_id as String (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_call_id'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
        )

    # Add reasoning_tokens (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'reasoning_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
        )

    # Rename tool_arguments to tool_call_arguments (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_arguments'
            """
        )
    )
    if result.fetchone():
        op.alter_column(
            "tool_call", "tool_arguments", new_column_name="tool_call_arguments"
        )

    # Rename tool_result to tool_call_response and change type from JSONB to Text (if not already done)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name, data_type FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_result'
            """
        )
    )
    tool_result_row = result.fetchone()
    if tool_result_row:
        op.alter_column(
            "tool_call", "tool_result", new_column_name="tool_call_response"
        )
        # Change type from JSONB to Text
        op.execute(
            sa.text(
                """
                ALTER TABLE tool_call
                ALTER COLUMN tool_call_response TYPE TEXT
                USING tool_call_response::text
                """
            )
        )
    else:
        # Check if tool_call_response already exists and is JSONB, then convert to Text
        result = conn.execute(
            sa.text(
                """
                SELECT data_type FROM information_schema.columns
                WHERE table_name = 'tool_call' AND column_name = 'tool_call_response'
                """
            )
        )
        tool_call_response_row = result.fetchone()
        if tool_call_response_row and tool_call_response_row[0] == "jsonb":
            op.execute(
                sa.text(
                    """
                    ALTER TABLE tool_call
                    ALTER COLUMN tool_call_response TYPE TEXT
                    USING tool_call_response::text
                    """
                )
            )

    # Add tool_call_tokens (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_call_tokens'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column(
                "tool_call_tokens", sa.Integer(), nullable=False, server_default="0"
            ),
        )

    # Add generated_images column for image generation tool replay (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'generated_images'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "tool_call",
            sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        )

    # Drop tool_name column (if exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'tool_call' AND column_name = 'tool_name'
            """
        )
    )
    if result.fetchone():
        op.drop_column("tool_call", "tool_name")

    # Create tool_call__search_doc association table (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT table_name FROM information_schema.tables
            WHERE table_name = 'tool_call__search_doc'
            """
        )
    )
    if not result.fetchone():
        op.create_table(
            "tool_call__search_doc",
            sa.Column("tool_call_id", sa.Integer(), nullable=False),
            sa.Column("search_doc_id", sa.Integer(), nullable=False),
            sa.ForeignKeyConstraint(
                ["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"
            ),
            sa.ForeignKeyConstraint(
                ["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
            ),
            sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
        )

    # Add replace_base_system_prompt to persona table (if not exists)
    result = conn.execute(
        sa.text(
            """
            SELECT column_name FROM information_schema.columns
            WHERE table_name = 'persona' AND column_name = 'replace_base_system_prompt'
            """
        )
    )
    if not result.fetchone():
        op.add_column(
            "persona",
            sa.Column(
                "replace_base_system_prompt",
                sa.Boolean(),
                nullable=False,
                server_default="false",
            ),
        )


def downgrade() -> None:
    # Reverse persona changes
    op.drop_column("persona", "replace_base_system_prompt")

    # Drop tool_call__search_doc association table
    op.execute("DROP TABLE IF EXISTS tool_call__search_doc CASCADE")

    # Reverse ToolCall changes
    op.add_column("tool_call", sa.Column("tool_name", sa.String(), nullable=False))
    op.drop_column("tool_call", "tool_id")
    op.drop_column("tool_call", "tool_call_tokens")
    op.drop_column("tool_call", "generated_images")
    # Change tool_call_response back to JSONB before renaming
    op.execute(
        sa.text(
            """
            ALTER TABLE tool_call
            ALTER COLUMN tool_call_response TYPE JSONB
            USING tool_call_response::jsonb
            """
        )
    )
    op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
    op.alter_column(
        "tool_call", "tool_call_arguments", new_column_name="tool_arguments"
    )
    op.drop_column("tool_call", "reasoning_tokens")
    op.drop_column("tool_call", "tool_call_id")
    op.drop_column("tool_call", "turn_number")
    op.drop_constraint(
        "fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
    )
    op.drop_column("tool_call", "parent_tool_call_id")
    op.alter_column(
        "tool_call",
        "parent_chat_message_id",
        new_column_name="message_id",
        nullable=False,
    )
    op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
    op.drop_column("tool_call", "chat_session_id")

    op.add_column(
        "chat_message",
        sa.Column(
            "research_answer_purpose",
            sa.Enum("INTRO", "DEEP_DIVE", name="researchanswerpurpose"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "research_type",
            sa.Enum("SIMPLE", "DEEP", name="researchtype"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
    )
    op.drop_column("chat_message", "reasoning_tokens")
    op.drop_constraint(
        "fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message",
        "latest_child_message_id",
        new_column_name="latest_child_message",
    )
    op.drop_constraint(
        "fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
    )
    op.alter_column(
        "chat_message", "parent_message_id", new_column_name="parent_message"
    )

    # Recreate agent sub question and sub query tables
    op.create_table(
        "agent__sub_question",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_question", sa.Text(), nullable=False),
        sa.Column("level", sa.Integer(), nullable=False),
        sa.Column("level_question_num", sa.Integer(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_answer", sa.Text(), nullable=False),
        sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_table(
        "agent__sub_query",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("parent_question_id", sa.Integer(), nullable=False),
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("sub_query", sa.Text(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_table(
        "agent__sub_query__search_doc",
        sa.Column("sub_query_id", sa.Integer(), nullable=False),
        sa.Column("search_doc_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
        sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
    )

    # Recreate research agent tables
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "primary_question_id",
            "iteration_nr",
            name="_research_agent_iteration_unique_constraint",
        ),
    )

    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id", "iteration_nr"],
            [
                "research_agent_iteration.primary_question_id",
                "research_agent_iteration.iteration_nr",
            ],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
        sa.PrimaryKeyConstraint("id"),
    )
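The migration above repeats the same information_schema probe many times to stay idempotent. A small helper like the sketch below would factor that pattern; it assumes the same Alembic connection object and default schema used above, and is not code from this diff.

# Sketch of a helper the repeated existence checks could share.
import sqlalchemy as sa

def column_exists(conn: sa.Connection, table: str, column: str) -> bool:
    return (
        conn.execute(
            sa.text(
                """
                SELECT 1 FROM information_schema.columns
                WHERE table_name = :table AND column_name = :column
                """
            ),
            {"table": table, "column": column},
        ).fetchone()
        is not None
    )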
backend/alembic/versions/c7e9f4a3b2d1_add_python_tool.py (73 changes, new file)
@@ -0,0 +1,73 @@
"""add_python_tool

Revision ID: c7e9f4a3b2d1
Revises: 3c9a65f1207f
Create Date: 2025-11-08 00:00:00.000000

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "c7e9f4a3b2d1"
down_revision = "3c9a65f1207f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add PythonTool to built-in tools"""
    conn = op.get_bind()

    conn.execute(
        sa.text(
            """
            INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
            VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
            """
        ),
        {
            "name": "PythonTool",
            # in the UI, call it `Code Interpreter` since this is a well known term for this tool
            "display_name": "Code Interpreter",
            "description": (
                "The Code Interpreter Action allows the assistant to execute "
                "Python code in a secure, isolated environment for data analysis, "
                "computation, visualization, and file processing."
            ),
            "in_code_tool_id": "PythonTool",
            "enabled": True,
        },
    )

    # needed to store files generated by the python tool
    op.add_column(
        "research_agent_iteration_sub_step",
        sa.Column(
            "file_ids",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )


def downgrade() -> None:
    """Remove PythonTool from built-in tools"""
    conn = op.get_bind()

    conn.execute(
        sa.text(
            """
            DELETE FROM tool
            WHERE in_code_tool_id = :in_code_tool_id
            """
        ),
        {
            "in_code_tool_id": "PythonTool",
        },
    )

    op.drop_column("research_agent_iteration_sub_step", "file_ids")
@@ -1,4 +1,16 @@
variable "REPOSITORY" {
group "default" {
  targets = ["backend", "model-server"]
}

variable "BACKEND_REPOSITORY" {
  default = "onyxdotapp/onyx-backend"
}

variable "MODEL_SERVER_REPOSITORY" {
  default = "onyxdotapp/onyx-model-server"
}

variable "INTEGRATION_REPOSITORY" {
  default = "onyxdotapp/onyx-integration"
}

@@ -9,6 +21,22 @@ variable "TAG" {
target "backend" {
  context = "."
  dockerfile = "Dockerfile"

  cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
  cache-to = ["type=inline"]

  tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}

target "model-server" {
  context = "."

  dockerfile = "Dockerfile.model_server"

  cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
  cache-to = ["type=inline"]

  tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
}

target "integration" {
@@ -20,8 +48,5 @@ target "integration" {
    base = "target:backend"
  }

  cache-from = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache"]
  cache-to = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache,mode=max"]

  tags = ["${REPOSITORY}:${TAG}"]
  tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
}
@@ -124,6 +124,8 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"

MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")

HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

GATED_TENANTS_KEY = "gated_tenants"
@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == persona_id,
                ChatSession.persona_id == persona_id,
            ),
            ChatSession.persona_id == persona_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == persona_id,
                ChatSession.persona_id == persona_id,
            ),
            ChatSession.persona_id == persona_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            or_(
                ChatMessage.alternate_assistant_id == assistant_id,
                ChatSession.persona_id == assistant_id,
            ),
            ChatSession.persona_id == assistant_id,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(
|
||||
|
||||
# Get assistant name (from session persona, or alternate if specified)
|
||||
assistant_name = None
|
||||
if message.alternate_assistant_id:
|
||||
# If there's an alternate assistant, we need to fetch it
|
||||
from onyx.db.models import Persona
|
||||
|
||||
alternate_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.id == message.alternate_assistant_id)
|
||||
.first()
|
||||
)
|
||||
if alternate_persona:
|
||||
assistant_name = alternate_persona.name
|
||||
elif chat_session.persona:
|
||||
if chat_session.persona:
|
||||
assistant_name = chat_session.persona.name
|
||||
|
||||
message_skeletons.append(
|
||||
|
||||
@@ -581,6 +581,48 @@ def update_user_curator_relationship(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_users_to_user_group(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
user_group_id: int,
|
||||
user_ids: list[UUID],
|
||||
) -> UserGroup:
|
||||
db_user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id)
|
||||
if db_user_group is None:
|
||||
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
|
||||
|
||||
missing_users = [
|
||||
user_id for user_id in user_ids if fetch_user_by_id(db_session, user_id) is None
|
||||
]
|
||||
if missing_users:
|
||||
raise ValueError(
|
||||
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
|
||||
)
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
current_user_ids = [user.id for user in db_user_group.users]
|
||||
current_user_ids_set = set(current_user_ids)
|
||||
new_user_ids = [
|
||||
user_id for user_id in user_ids if user_id not in current_user_ids_set
|
||||
]
|
||||
|
||||
if not new_user_ids:
|
||||
return db_user_group
|
||||
|
||||
user_group_update = UserGroupUpdate(
|
||||
user_ids=current_user_ids + new_user_ids,
|
||||
cc_pair_ids=[cc_pair.id for cc_pair in db_user_group.cc_pairs],
|
||||
)
|
||||
|
||||
return update_user_group(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
user_group_id=user_group_id,
|
||||
user_group_update=user_group_update,
|
||||
)
|
||||
|
||||
|
||||
def update_user_group(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
@@ -603,6 +645,17 @@ def update_user_group(
|
||||
added_user_ids = list(updated_user_ids - current_user_ids)
|
||||
removed_user_ids = list(current_user_ids - updated_user_ids)
|
||||
|
||||
if added_user_ids:
|
||||
missing_users = [
|
||||
user_id
|
||||
for user_id in added_user_ids
|
||||
if fetch_user_by_id(db_session, user_id) is None
|
||||
]
|
||||
if missing_users:
|
||||
raise ValueError(
|
||||
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
|
||||
)
|
||||
|
||||
# LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES
|
||||
# ACCESS TO DIFFERENT PERMISSIONS
|
||||
# if (removed_user_ids or added_user_ids) and (
|
||||
|
||||
@@ -23,7 +23,7 @@ from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.query_backend import (
|
||||
basic_router as query_router,
|
||||
basic_router as ee_query_router,
|
||||
)
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
@@ -48,6 +48,9 @@ from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.main import lifespan as lifespan_base
|
||||
from onyx.main import use_route_function_names_as_operation_ids
|
||||
from onyx.server.query_and_chat.query_backend import (
|
||||
basic_router as query_router,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -119,6 +122,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, query_history_router)
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
|
||||
@@ -9,7 +9,7 @@ from ee.onyx.server.query_and_chat.models import (
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
@@ -69,7 +69,7 @@ def handle_simplified_chat_message(
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message, _ = create_chat_chain(
|
||||
parent_message, _ = create_chat_history_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -6,18 +6,14 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import BasicChunkRequest
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
@@ -29,14 +25,12 @@ class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSearchRequest(ChunkContext):
|
||||
message: str
|
||||
search_type: SearchType
|
||||
retrieval_options: RetrievalDetails
|
||||
recency_bias_multiplier: float = 1.0
|
||||
evaluation_type: LLMEvaluationType
|
||||
# None to use system defaults for reranking
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
class DocumentSearchRequest(BasicChunkRequest):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
|
||||
|
||||
class DocumentSearchResponse(BaseModel):
|
||||
top_documents: list[InferenceChunk]
|
||||
|
||||
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
@@ -96,17 +90,17 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
class AgentSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
class AgentAnswer(BaseModel):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
class AgentSubQuery(BaseModel):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@@ -152,45 +146,3 @@ class AgentSubQuery(SubQuestionIdentifier):
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
# Easier APIs to work with for developers
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generating an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
elif self.persona_override_config is not None and (self.persona_id is not None):
|
||||
raise ValueError(
|
||||
"If persona_override_config is set, persona_id cannot be set"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
|

@@ -1,316 +1,23 @@
import json
from collections.abc import Generator

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
    oneoff_standard_answers,
)
from ee.onyx.server.query_and_chat.models import DocumentSearchRequest
from ee.onyx.server.query_and_chat.models import OneShotQARequest
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from ee.onyx.server.query_and_chat.models import StandardAnswerRequest
from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import AnswerStream
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger


logger = setup_logger()

basic_router = APIRouter(prefix="/query")


class DocumentSearchPagination(BaseModel):
    offset: int
    limit: int
    returned_count: int
    has_more: bool
    next_offset: int | None = None


class DocumentSearchResponse(BaseModel):
    top_documents: list[SavedSearchDocWithContent]
    llm_indices: list[int]
    pagination: DocumentSearchPagination


def _normalize_pagination(limit: int | None, offset: int | None) -> tuple[int, int]:
    if limit is None:
        resolved_limit = NUM_RETURNED_HITS
    else:
        resolved_limit = limit

    if resolved_limit <= 0:
        raise HTTPException(
            status_code=400, detail="retrieval_options.limit must be positive"
        )

    if offset is None:
        resolved_offset = 0
    else:
        resolved_offset = offset

    if resolved_offset < 0:
        raise HTTPException(
            status_code=400, detail="retrieval_options.offset cannot be negative"
        )

    return resolved_limit, resolved_offset


@basic_router.post("/document-search")
def handle_search_request(
    search_request: DocumentSearchRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DocumentSearchResponse:
    """Simple search endpoint, does not create a new message or records in the DB"""
    query = search_request.message
    logger.notice(f"Received document search query: {query}")

    llm, fast_llm = get_default_llms()
    pagination_limit, pagination_offset = _normalize_pagination(
        limit=search_request.retrieval_options.limit,
        offset=search_request.retrieval_options.offset,
    )

    search_pipeline = SearchPipeline(
        search_request=SearchRequest(
            query=query,
            search_type=search_request.search_type,
            human_selected_filters=search_request.retrieval_options.filters,
            enable_auto_detect_filters=search_request.retrieval_options.enable_auto_detect_filters,
            persona=None,  # For simplicity, default settings should be good for this search
            offset=pagination_offset,
            limit=pagination_limit + 1,
            rerank_settings=search_request.rerank_settings,
            evaluation_type=search_request.evaluation_type,
            chunks_above=search_request.chunks_above,
            chunks_below=search_request.chunks_below,
            full_doc=search_request.full_doc,
        ),
        user=user,
        llm=llm,
        fast_llm=fast_llm,
        skip_query_analysis=False,
        db_session=db_session,
        bypass_acl=False,
    )
    top_sections = search_pipeline.reranked_sections
    relevance_sections = search_pipeline.section_relevance
    top_docs = [
        SavedSearchDocWithContent(
            document_id=section.center_chunk.document_id,
            chunk_ind=section.center_chunk.chunk_id,
            content=section.center_chunk.content,
            semantic_identifier=section.center_chunk.semantic_identifier or "Unknown",
            link=(
                section.center_chunk.source_links.get(0)
                if section.center_chunk.source_links
                else None
            ),
            blurb=section.center_chunk.blurb,
            source_type=section.center_chunk.source_type,
            boost=section.center_chunk.boost,
            hidden=section.center_chunk.hidden,
            metadata=section.center_chunk.metadata,
            score=section.center_chunk.score or 0.0,
            match_highlights=section.center_chunk.match_highlights,
            updated_at=section.center_chunk.updated_at,
            primary_owners=section.center_chunk.primary_owners,
            secondary_owners=section.center_chunk.secondary_owners,
            is_internet=False,
            db_doc_id=0,
        )
        for section in top_sections
    ]

    # Track whether the underlying retrieval produced more items than requested
    has_more_results = len(top_docs) > pagination_limit

    # Deduping happens at the last step to avoid harming quality by dropping content early on
    deduped_docs = top_docs
    dropped_inds = None

    if search_request.retrieval_options.dedupe_docs:
        deduped_docs, dropped_inds = dedupe_documents(top_docs)

    llm_indices = relevant_sections_to_indices(
        relevance_sections=relevance_sections, items=deduped_docs
    )

    if dropped_inds:
        llm_indices = drop_llm_indices(
            llm_indices=llm_indices,
            search_docs=deduped_docs,
            dropped_indices=dropped_inds,
        )

    paginated_docs = deduped_docs[:pagination_limit]
    llm_indices = [index for index in llm_indices if index < len(paginated_docs)]
    has_more = has_more_results
    pagination = DocumentSearchPagination(
        offset=pagination_offset,
        limit=pagination_limit,
        returned_count=len(paginated_docs),
        has_more=has_more,
        next_offset=(pagination_offset + pagination_limit) if has_more else None,
    )

    return DocumentSearchResponse(
        top_documents=paginated_docs,
        llm_indices=llm_indices,
        pagination=pagination,
    )
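
Note: the endpoint above uses the common limit-plus-one pagination trick. It asks the retrieval layer for pagination_limit + 1 results, and the presence of that extra result is what drives has_more and next_offset. A self-contained sketch of the same pattern with generic names (none of these are Onyx APIs):

from dataclasses import dataclass
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


@dataclass
class Page:
    items: Sequence
    has_more: bool
    next_offset: int | None


def fetch_page(fetch: Callable[[int, int], Sequence[T]], offset: int, limit: int) -> Page:
    # Ask the backend for one row more than requested; the extra row only
    # signals that another page exists and is never returned to the caller.
    rows = fetch(offset, limit + 1)
    has_more = len(rows) > limit
    return Page(
        items=rows[:limit],
        has_more=has_more,
        next_offset=offset + limit if has_more else None,
    )


# Usage against an in-memory "backend":
data = list(range(25))
page = fetch_page(lambda off, lim: data[off : off + lim], offset=20, limit=10)
assert list(page.items) == data[20:25] and not page.has_more and page.next_offset is None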

def get_answer_stream(
    query_request: OneShotQARequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AnswerStream:
    query = query_request.messages[0].message
    logger.notice(f"Received query for Answer API: {query}")

    if (
        query_request.persona_override_config is None
        and query_request.persona_id is None
    ):
        raise KeyError("Must provide persona ID or Persona Config")

    persona_info: Persona | PersonaOverrideConfig | None = None
    if query_request.persona_override_config is not None:
        persona_info = query_request.persona_override_config
    elif query_request.persona_id is not None:
        persona_info = get_persona_by_id(
            persona_id=query_request.persona_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )

    llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))

    llm_tokenizer = get_tokenizer(
        model_name=llm.config.model_name,
        provider_type=llm.config.model_provider,
    )

    max_history_tokens = int(
        llm.config.max_input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE
    )

    combined_message = combine_message_thread(
        messages=query_request.messages,
        max_tokens=max_history_tokens,
        llm_tokenizer=llm_tokenizer,
    )

    # Also creates a new chat session
    request = prepare_chat_message_request(
        message_text=combined_message,
        user=user,
        persona_id=query_request.persona_id,
        persona_override_config=query_request.persona_override_config,
        message_ts_to_respond_to=None,
        retrieval_details=query_request.retrieval_options,
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_agentic_search=query_request.use_agentic_search,
        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
    )

    packets = stream_chat_message_objects(
        new_msg_req=request,
        user=user,
        db_session=db_session,
    )

    return packets


@basic_router.post("/answer-with-citation")
def get_answer_with_citation(
    request: OneShotQARequest,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_user),
) -> OneShotQAResponse:
    try:
        packets = get_answer_stream(request, user, db_session)
        answer = gather_stream(packets)

        if answer.error_msg:
            raise RuntimeError(answer.error_msg)

        return OneShotQAResponse(
            answer=answer.answer,
            chat_message_id=answer.message_id,
            error_msg=answer.error_msg,
            citations=[
                CitationInfo(citation_num=i, document_id=doc_id)
                for i, doc_id in answer.cited_documents.items()
            ],
            docs=QADocsResponse(
                top_documents=answer.top_documents,
                predicted_flow=None,
                predicted_search=None,
                applied_source_filters=None,
                applied_time_cutoff=None,
                recency_bias_multiplier=0.0,
            ),
        )
    except Exception as e:
        logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal server error occurred")


@basic_router.post("/stream-answer-with-citation")
def stream_answer_with_citation(
    request: OneShotQARequest,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_user),
) -> StreamingResponse:
    def stream_generator() -> Generator[str, None, None]:
        try:
            for packet in get_answer_stream(request, user, db_session):
                serialized = get_json_line(packet.model_dump())
                yield serialized
        except Exception as e:
            logger.exception("Error in answer streaming")
            yield json.dumps({"error": str(e)})

    return StreamingResponse(stream_generator(), media_type="application/json")


@basic_router.get("/standard-answer")
def get_standard_answer(
    request: StandardAnswerRequest,

@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
) -> ChatSessionSnapshot | None:
    try:
        # Older chats may not have the right structure
        last_message, messages = create_chat_chain(
        messages = create_chat_history_chain(
            chat_session_id=chat_session.id, db_session=db_session
        )
        messages.append(last_message)
    except RuntimeError:
        return None

@@ -4,12 +4,14 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
@@ -79,6 +81,26 @@ def patch_user_group(
        raise HTTPException(status_code=404, detail=str(e))


@router.post("/admin/user-group/{user_group_id}/add-users")
def add_users(
    user_group_id: int,
    add_users_request: AddUsersToUserGroupRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> UserGroup:
    try:
        return UserGroup.from_model(
            add_users_to_user_group(
                db_session=db_session,
                user=user,
                user_group_id=user_group_id,
                user_ids=add_users_request.user_ids,
            )
        )
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
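
Note: a quick way to exercise the new add-users route is a plain HTTP call. The sketch below is illustrative only; the base URL, the /api prefix, and the auth cookie name are assumptions about a local deployment, not values taken from this diff.

# Hypothetical client for POST /admin/user-group/{id}/add-users.
import uuid

import requests

resp = requests.post(
    "http://localhost:8080/api/admin/user-group/1/add-users",
    json={"user_ids": [str(uuid.uuid4())]},  # must be IDs of existing users
    cookies={"session": "<auth-cookie>"},  # auth mechanism is an assumption
    timeout=30,
)
# 200 returns the updated UserGroup; 404 covers the ValueError paths
# (unknown group or unknown user).
print(resp.status_code, resp.json())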

@router.post("/admin/user-group/{user_group_id}/set-curator")
def set_user_curator(
    user_group_id: int,

@@ -87,6 +87,10 @@ class UserGroupUpdate(BaseModel):
    cc_pair_ids: list[int]


class AddUsersToUserGroupRequest(BaseModel):
    user_ids: list[UUID]


class SetCuratorRequest(BaseModel):
    user_id: UUID
    is_curator: bool

@@ -1,7 +1,10 @@
import json
from typing import Any
from urllib.parse import unquote

from posthog import Posthog

from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
@@ -20,3 +23,80 @@ posthog = Posthog(
    debug=True,
    on_error=posthog_on_error,
)

# For cross-referencing between the cloud and www Onyx sites.
# NOTE: These clients are separate because they are separate PostHog projects.
# We should eventually unify them into a single PostHog project,
# which would no longer require this workaround.
marketing_posthog = None
if MARKETING_POSTHOG_API_KEY:
    marketing_posthog = Posthog(
        project_api_key=MARKETING_POSTHOG_API_KEY,
        host=POSTHOG_HOST,
        debug=True,
        on_error=posthog_on_error,
    )


def capture_and_sync_with_alternate_posthog(
    alternate_distinct_id: str, event: str, properties: dict[str, Any]
) -> None:
    """
    Identify in both PostHog projects and capture the event in marketing.
    - Marketing keeps the marketing distinct_id (for feature flags).
    - Cloud identify uses the cloud distinct_id.
    """
    if not marketing_posthog:
        return

    props = properties.copy()

    try:
        marketing_posthog.identify(distinct_id=alternate_distinct_id, properties=props)
        marketing_posthog.capture(alternate_distinct_id, event, props)
        marketing_posthog.flush()
    except Exception as e:
        logger.error(f"Error capturing marketing posthog event: {e}")

    try:
        if cloud_user_id := props.get("onyx_cloud_user_id"):
            cloud_props = props.copy()
            cloud_props.pop("onyx_cloud_user_id", None)

            posthog.identify(
                distinct_id=cloud_user_id,
                properties=cloud_props,
            )
    except Exception as e:
        logger.error(f"Error identifying cloud posthog user: {e}")


def get_marketing_posthog_cookie_name() -> str | None:
    if not MARKETING_POSTHOG_API_KEY:
        return None
    return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog"


def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
    """
    Parse the URL-encoded JSON marketing cookie.

    Expected format (URL-encoded):
        {"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...}

    Returns:
        Dict with 'distinct_id' explicitly required and all other cookie values
        passed through as-is, or None if parsing fails or distinct_id is missing.
    """
    try:
        decoded_cookie = unquote(cookie_value)
        cookie_data = json.loads(decoded_cookie)

        distinct_id = cookie_data.get("distinct_id")
        if not distinct_id:
            return None

        return cookie_data
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
        logger.warning(f"Failed to parse cookie: {e}")
        return None
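
Note: the cookie format that parse_marketing_cookie expects is easy to reproduce with the standard library alone. This standalone snippet (not part of the diff) round-trips a value shaped like the docstring's example:

import json
from urllib.parse import quote, unquote

raw = {"distinct_id": "abc-123", "featureFlags": {"landing_page_variant": "b"}}
cookie_value = quote(json.dumps(raw))  # what the browser-side SDK would store

decoded = json.loads(unquote(cookie_value))
assert decoded["distinct_id"] == "abc-123"
assert decoded["featureFlags"]["landing_page_variant"] == "b"
# A payload without distinct_id is the case where parse_marketing_cookie returns None.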

@@ -116,7 +116,7 @@ def _concurrent_embedding(
        # the model to fail to encode texts. It's pretty rare and we want to allow
        # concurrent embedding, hence we retry (the specific error is
        # "RuntimeError: Already borrowed" and occurs in the transformers library)
        logger.error(f"Error encoding texts, retrying: {e}")
        logger.warning(f"Error encoding texts, retrying: {e}")
        time.sleep(ENCODING_RETRY_DELAY)
        return model.encode(texts, normalize_embeddings=normalize_embeddings)

backend/onyx/agents/agent_framework/message_format.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import json
from collections.abc import Sequence
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import FunctionMessage

from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import FunctionCall
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithText


HUMAN = "human"
SYSTEM = "system"
AI = "ai"
FUNCTION = "function"


def base_messages_to_chat_completion_msgs(
    msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
    return [_base_message_to_chat_completion_msg(msg) for msg in msgs]


def _base_message_to_chat_completion_msg(
    msg: BaseMessage,
) -> ChatCompletionMessage:
    if msg.type == HUMAN:
        content = msg.content if isinstance(msg.content, str) else str(msg.content)
        user_msg: UserMessageWithText = {"role": "user", "content": content}
        return user_msg
    if msg.type == SYSTEM:
        content = msg.content if isinstance(msg.content, str) else str(msg.content)
        system_msg: SystemMessage = {"role": "system", "content": content}
        return system_msg
    if msg.type == AI:
        content = msg.content if isinstance(msg.content, str) else str(msg.content)
        assistant_msg: AssistantMessage = {
            "role": "assistant",
            "content": content,
        }
        if isinstance(msg, AIMessage) and msg.tool_calls:
            assistant_msg["tool_calls"] = [
                ToolCall(
                    id=tool_call.get("id") or "",
                    type="function",
                    function=FunctionCall(
                        name=tool_call["name"],
                        arguments=json.dumps(tool_call["args"]),
                    ),
                )
                for tool_call in msg.tool_calls
            ]
        return assistant_msg
    if msg.type == FUNCTION:
        function_message = cast(FunctionMessage, msg)
        content = (
            function_message.content
            if isinstance(function_message.content, str)
            else str(function_message.content)
        )
        tool_msg: ToolMessage = {
            "role": "tool",
            "content": content,
            "tool_call_id": function_message.name or "",
        }
        return tool_msg
    raise ValueError(f"Unexpected message type: {msg.type}")
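
Note: a minimal usage sketch for the converter above, assuming the TypedDicts in onyx.llm.message_types mirror the OpenAI chat-completion dict shapes (role/content keys); this snippet is illustrative and depends on the Onyx modules shown in the diff:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from onyx.agents.agent_framework.message_format import (
    base_messages_to_chat_completion_msgs,
)

converted = base_messages_to_chat_completion_msgs(
    [
        SystemMessage(content="You are terse."),
        HumanMessage(content="ping"),
        AIMessage(content="pong"),
    ]
)
assert [m["role"] for m in converted] == ["system", "user", "assistant"]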

@@ -1,215 +1,309 @@
import json
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
# import json
# from collections.abc import Callable
# from collections.abc import Iterator
# from collections.abc import Sequence
# from dataclasses import dataclass
# from typing import Any

from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
# from onyx.agents.agent_framework.models import RunItemStreamEvent
# from onyx.agents.agent_framework.models import StreamEvent
# from onyx.agents.agent_framework.models import ToolCallStreamItem
# from onyx.llm.interfaces import LanguageModelInput
# from onyx.llm.interfaces import LLM
# from onyx.llm.interfaces import ToolChoiceOptions
# from onyx.llm.message_types import ChatCompletionMessage
# from onyx.llm.message_types import ToolCall
# from onyx.llm.model_response import ModelResponseStream
# from onyx.tools.tool import Tool
# from onyx.tracing.framework.create import agent_span
# from onyx.tracing.framework.create import generation_span


@dataclass
class QueryResult:
    stream: Iterator[StreamEvent]
    new_messages_stateful: list[ChatCompletionMessage]
# @dataclass
# class QueryResult:
#     stream: Iterator[StreamEvent]
#     new_messages_stateful: list[ChatCompletionMessage]


def _serialize_tool_output(output: Any) -> str:
    if isinstance(output, str):
        return output
    try:
        return json.dumps(output)
    except TypeError:
        return str(output)
# def _serialize_tool_output(output: Any) -> str:
#     if isinstance(output, str):
#         return output
#     try:
#         return json.dumps(output)
#     except TypeError:
#         return str(output)


def _update_tool_call_with_delta(
    tool_calls_in_progress: dict[int, dict[str, Any]],
    tool_call_delta: Any,
) -> None:
    index = tool_call_delta.index
# def _parse_tool_calls_from_message_content(
#     content: str,
# ) -> list[dict[str, Any]]:
#     """Parse JSON content that represents tool call instructions."""
#     try:
#         parsed_content = json.loads(content)
#     except json.JSONDecodeError:
#         return []

    if index not in tool_calls_in_progress:
        tool_calls_in_progress[index] = {
            "id": None,
            "name": None,
            "arguments": "",
        }
# if isinstance(parsed_content, dict):
#     candidates = [parsed_content]
# elif isinstance(parsed_content, list):
#     candidates = [item for item in parsed_content if isinstance(item, dict)]
# else:
#     return []

    if tool_call_delta.id:
        tool_calls_in_progress[index]["id"] = tool_call_delta.id
# tool_calls: list[dict[str, Any]] = []

    if tool_call_delta.function:
        if tool_call_delta.function.name:
            tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
# for candidate in candidates:
#     name = candidate.get("name")
#     arguments = candidate.get("arguments")

        if tool_call_delta.function.arguments:
            tool_calls_in_progress[index][
                "arguments"
            ] += tool_call_delta.function.arguments
# if not isinstance(name, str) or arguments is None:
#     continue

# if not isinstance(arguments, dict):
#     continue

# call_id = candidate.get("id")
# arguments_str = json.dumps(arguments)
# tool_calls.append(
#     {
#         "id": call_id,
#         "name": name,
#         "arguments": arguments_str,
#     }
# )

# return tool_calls


def query(
    llm_with_default_settings: LLM,
    messages: LanguageModelInput,
    tools: Sequence[Tool],
    context: Any,
    tool_choice: ToolChoiceOptions | None = None,
) -> QueryResult:
    tool_definitions = [tool.tool_definition() for tool in tools]
    tools_by_name = {tool.name: tool for tool in tools}
# def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
#     tool_calls_in_progress: dict[int, dict[str, Any]],
#     content_parts: list[str],
#     structured_response_format: dict | None,
#     next_synthetic_tool_call_id: Callable[[], str],
# ) -> None:
#     """Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
#     if tool_calls_in_progress or not content_parts or structured_response_format:
#         return

    new_messages_stateful: list[ChatCompletionMessage] = []
# tool_calls_from_content = _parse_tool_calls_from_message_content(
#     "".join(content_parts)
# )

    def stream_generator() -> Iterator[StreamEvent]:
        reasoning_started = False
        message_started = False
# if not tool_calls_from_content:
#     return

        tool_calls_in_progress: dict[int, dict[str, Any]] = {}
# content_parts.clear()

        content_parts: list[str] = []
        reasoning_parts: list[str] = []
# for index, tool_call_data in enumerate(tool_calls_from_content):
#     call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
#     tool_calls_in_progress[index] = {
#         "id": call_id,
#         "name": tool_call_data["name"],
#         "arguments": tool_call_data["arguments"],
#     }

        for chunk in llm_with_default_settings.stream(
            prompt=messages,
            tools=tool_definitions,
            tool_choice=tool_choice,
        ):
            assert isinstance(chunk, ModelResponseStream)

            delta = chunk.choice.delta
            finish_reason = chunk.choice.finish_reason
# def _update_tool_call_with_delta(
#     tool_calls_in_progress: dict[int, dict[str, Any]],
#     tool_call_delta: Any,
# ) -> None:
#     index = tool_call_delta.index

            if delta.reasoning_content:
                reasoning_parts.append(delta.reasoning_content)
                if not reasoning_started:
                    yield RunItemStreamEvent(type="reasoning_start")
                    reasoning_started = True
# if index not in tool_calls_in_progress:
#     tool_calls_in_progress[index] = {
#         "id": None,
#         "name": None,
#         "arguments": "",
#     }

            if delta.content:
                content_parts.append(delta.content)
                if reasoning_started:
                    yield RunItemStreamEvent(type="reasoning_done")
                    reasoning_started = False
                if not message_started:
                    yield RunItemStreamEvent(type="message_start")
                    message_started = True
# if tool_call_delta.id:
#     tool_calls_in_progress[index]["id"] = tool_call_delta.id

            if delta.tool_calls:
                if reasoning_started and not message_started:
                    yield RunItemStreamEvent(type="reasoning_done")
                    reasoning_started = False
                if message_started:
                    yield RunItemStreamEvent(type="message_done")
                    message_started = False
# if tool_call_delta.function:
#     if tool_call_delta.function.name:
#         tool_calls_in_progress[index]["name"] = tool_call_delta.function.name

                for tool_call_delta in delta.tool_calls:
                    _update_tool_call_with_delta(
                        tool_calls_in_progress, tool_call_delta
                    )
#     if tool_call_delta.function.arguments:
#         tool_calls_in_progress[index][
#             "arguments"
#         ] += tool_call_delta.function.arguments

            yield chunk

            if not finish_reason:
                continue
            if message_started:
                yield RunItemStreamEvent(type="message_done")
                message_started = False
# def query(
#     llm_with_default_settings: LLM,
#     messages: LanguageModelInput,
#     tools: Sequence[Tool],
#     context: Any,
#     tool_choice: ToolChoiceOptions | None = None,
#     structured_response_format: dict | None = None,
# ) -> QueryResult:
#     tool_definitions = [tool.tool_definition() for tool in tools]
#     tools_by_name = {tool.name: tool for tool in tools}

            if finish_reason == "tool_calls" and tool_calls_in_progress:
                sorted_tool_calls = sorted(tool_calls_in_progress.items())
# new_messages_stateful: list[ChatCompletionMessage] = []

                # Build tool calls for the message and execute tools
                assistant_tool_calls: list[ToolCall] = []
                tool_outputs: dict[str, str] = {}
# current_span = agent_span(
#     name="agent_framework_query",
#     output_type="dict" if structured_response_format else "str",
# )
# current_span.start(mark_as_current=True)
# current_span.span_data.tools = [t.name for t in tools]

                for _, tool_call_data in sorted_tool_calls:
                    call_id = tool_call_data["id"]
                    name = tool_call_data["name"]
                    arguments_str = tool_call_data["arguments"]
# def stream_generator() -> Iterator[StreamEvent]:
#     message_started = False
#     reasoning_started = False

                    if call_id is None or name is None:
                        continue
# tool_calls_in_progress: dict[int, dict[str, Any]] = {}

                    assistant_tool_calls.append(
                        {
                            "id": call_id,
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": arguments_str,
                            },
                        }
                    )
# content_parts: list[str] = []

                    yield RunItemStreamEvent(
                        type="tool_call",
                        details=ToolCallStreamItem(
                            call_id=call_id,
                            name=name,
                            arguments=arguments_str,
                        ),
                    )
# synthetic_tool_call_counter = 0

                    if name in tools_by_name:
                        tool = tools_by_name[name]
                        arguments = json.loads(arguments_str)
# def _next_synthetic_tool_call_id() -> str:
#     nonlocal synthetic_tool_call_counter
#     call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
#     synthetic_tool_call_counter += 1
#     return call_id

                        run_context = RunContextWrapper(context=context)
# with generation_span(  # type: ignore[misc]
#     model=llm_with_default_settings.config.model_name,
#     model_config={
#         "base_url": str(llm_with_default_settings.config.api_base or ""),
#         "model_impl": "litellm",
#     },
# ) as span_generation:
#     # Only set input if messages is a sequence (not a string)
#     # ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
#     if isinstance(messages, Sequence) and not isinstance(messages, str):
#         # Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
#         span_generation.span_data.input = [dict(msg) for msg in messages]  # type: ignore[assignment]
#     for chunk in llm_with_default_settings.stream(
#         prompt=messages,
#         tools=tool_definitions,
#         tool_choice=tool_choice,
#         structured_response_format=structured_response_format,
#     ):
#         assert isinstance(chunk, ModelResponseStream)
#         usage = getattr(chunk, "usage", None)
#         if usage:
#             span_generation.span_data.usage = {
#                 "input_tokens": usage.prompt_tokens,
#                 "output_tokens": usage.completion_tokens,
#                 "cache_read_input_tokens": usage.cache_read_input_tokens,
#                 "cache_creation_input_tokens": usage.cache_creation_input_tokens,
#             }

                        # TODO: Instead of executing sequentially, execute in parallel
                        # In practice, it's not a must right now since we don't use parallel
                        # tool calls, so kicking the can down the road for now.
                        output = tool.run_v2(run_context, **arguments)
                        tool_outputs[call_id] = _serialize_tool_output(output)
#         delta = chunk.choice.delta
#         finish_reason = chunk.choice.finish_reason

                        yield RunItemStreamEvent(
                            type="tool_call_output",
                            details=ToolCallOutputStreamItem(
                                call_id=call_id,
                                output=output,
                            ),
                        )
#         if delta.reasoning_content:
#             if not reasoning_started:
#                 yield RunItemStreamEvent(type="reasoning_start")
#                 reasoning_started = True

                new_messages_stateful.append(
                    {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": assistant_tool_calls,
                    }
                )
#         if delta.content:
#             if reasoning_started:
#                 yield RunItemStreamEvent(type="reasoning_done")
#                 reasoning_started = False
#             content_parts.append(delta.content)
#             if not message_started:
#                 yield RunItemStreamEvent(type="message_start")
#                 message_started = True

                for _, tool_call_data in sorted_tool_calls:
                    call_id = tool_call_data["id"]
#         if delta.tool_calls:
#             if reasoning_started:
#                 yield RunItemStreamEvent(type="reasoning_done")
#                 reasoning_started = False
#             if message_started:
#                 yield RunItemStreamEvent(type="message_done")
#                 message_started = False

                    if call_id in tool_outputs:
                        new_messages_stateful.append(
                            {
                                "role": "tool",
                                "content": tool_outputs[call_id],
                                "tool_call_id": call_id,
                            }
                        )
#             for tool_call_delta in delta.tool_calls:
#                 _update_tool_call_with_delta(
#                     tool_calls_in_progress, tool_call_delta
#                 )

            elif finish_reason == "stop" and content_parts:
                new_messages_stateful.append(
                    {
                        "role": "assistant",
                        "content": "".join(content_parts),
                    }
                )
#         yield chunk

    return QueryResult(
        stream=stream_generator(),
        new_messages_stateful=new_messages_stateful,
    )
#         if not finish_reason:
#             continue

#         if reasoning_started:
#             yield RunItemStreamEvent(type="reasoning_done")
#             reasoning_started = False
#         if message_started:
#             yield RunItemStreamEvent(type="message_done")
#             message_started = False

#     if tool_choice != "none":
#         _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
#             tool_calls_in_progress,
#             content_parts,
#             structured_response_format,
#             _next_synthetic_tool_call_id,
#         )

#     if content_parts:
#         new_messages_stateful.append(
#             {
#                 "role": "assistant",
#                 "content": "".join(content_parts),
#             }
#         )
#     span_generation.span_data.output = new_messages_stateful

#     # Execute tool calls outside of the stream loop and generation_span
#     if tool_calls_in_progress:
#         sorted_tool_calls = sorted(tool_calls_in_progress.items())

#         # Build tool calls for the message and execute tools
#         assistant_tool_calls: list[ToolCall] = []

#         for _, tool_call_data in sorted_tool_calls:
#             call_id = tool_call_data["id"]
#             name = tool_call_data["name"]
#             arguments_str = tool_call_data["arguments"]

#             if call_id is None or name is None:
#                 continue

#             assistant_tool_calls.append(
#                 {
#                     "id": call_id,
#                     "type": "function",
#                     "function": {
#                         "name": name,
#                         "arguments": arguments_str,
#                     },
#                 }
#             )

#             yield RunItemStreamEvent(
#                 type="tool_call",
#                 details=ToolCallStreamItem(
#                     call_id=call_id,
#                     name=name,
#                     arguments=arguments_str,
#                 ),
#             )

#             if name in tools_by_name:
#                 tools_by_name[name]
#                 json.loads(arguments_str)

#             run_context = RunContextWrapper(context=context)

# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.

# TODO broken for now, no need for a run_v2
# output = tool.run_v2(run_context, **arguments)

# yield RunItemStreamEvent(
#     type="tool_call_output",
#     details=ToolCallOutputStreamItem(
#         call_id=call_id,
#         output=output,
#     ),
# )
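
Note: _update_tool_call_with_delta above implements the standard pattern for reassembling streamed function calls: providers emit the call id, name, and argument JSON in fragments keyed by a call index, and the fragments are concatenated until a finish reason arrives. A self-contained sketch of that accumulation (the Delta classes are stand-ins for the provider's streaming types, not Onyx classes):

import json
from dataclasses import dataclass


@dataclass
class FunctionDelta:
    name: str | None = None
    arguments: str | None = None


@dataclass
class Delta:
    index: int
    id: str | None = None
    function: FunctionDelta | None = None


def accumulate(deltas: list[Delta]) -> dict[int, dict]:
    calls: dict[int, dict] = {}
    for d in deltas:
        call = calls.setdefault(d.index, {"id": None, "name": None, "arguments": ""})
        if d.id:
            call["id"] = d.id
        if d.function and d.function.name:
            call["name"] = d.function.name
        if d.function and d.function.arguments:
            call["arguments"] += d.function.arguments  # argument JSON arrives in pieces
    return calls


calls = accumulate(
    [
        Delta(index=0, id="call_1", function=FunctionDelta(name="search")),
        Delta(index=0, function=FunctionDelta(arguments='{"query": "on')),
        Delta(index=0, function=FunctionDelta(arguments='yx"}')),
    ]
)
assert calls[0]["name"] == "search"
assert json.loads(calls[0]["arguments"]) == {"query": "onyx"}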
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
# from operator import add
|
||||
# from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
# from pydantic import BaseModel
|
||||
|
||||
|
||||
class CoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
# class CoreState(BaseModel):
|
||||
# """
|
||||
# This is the core state that is shared across all subgraphs.
|
||||
# """
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
current_step_nr: int = 1
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
# current_step_nr: int = 1
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
# class SubgraphCoreState(BaseModel):
|
||||
# """
|
||||
# This is the core state that is shared across all subgraphs.
|
||||
# """
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
# log_messages: Annotated[list[str], add] = []
|
||||
|
||||
@@ -1,62 +1,62 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import cast
|
||||
# from collections.abc import Hashable
|
||||
# from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
# from langchain_core.runnables.config import RunnableConfig
|
||||
# from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
ObjectResearchInformationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
SearchSourcesObjectsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
|
||||
# from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
# ObjectResearchInformationUpdate,
|
||||
# )
|
||||
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
|
||||
# from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
# SearchSourcesObjectsUpdate,
|
||||
# )
|
||||
# from onyx.agents.agent_search.models import GraphConfig
|
||||
|
||||
|
||||
def parallel_object_source_research_edge(
|
||||
state: SearchSourcesObjectsUpdate, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the research for an individual object and source
|
||||
"""
|
||||
# def parallel_object_source_research_edge(
|
||||
# state: SearchSourcesObjectsUpdate, config: RunnableConfig
|
||||
# ) -> list[Send | Hashable]:
|
||||
# """
|
||||
# LangGraph edge to parallelize the research for an individual object and source
|
||||
# """
|
||||
|
||||
search_objects = state.analysis_objects
|
||||
search_sources = state.analysis_sources
|
||||
# search_objects = state.analysis_objects
|
||||
# search_sources = state.analysis_sources
|
||||
|
||||
object_source_combinations = [
|
||||
(object, source) for object in search_objects for source in search_sources
|
||||
]
|
||||
# object_source_combinations = [
|
||||
# (object, source) for object in search_objects for source in search_sources
|
||||
# ]
|
||||
|
||||
return [
|
||||
Send(
|
||||
"research_object_source",
|
||||
ObjectSourceInput(
|
||||
object_source_combination=object_source_combination,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for object_source_combination in object_source_combinations
|
||||
]
|
||||
# return [
|
||||
# Send(
|
||||
# "research_object_source",
|
||||
# ObjectSourceInput(
|
||||
# object_source_combination=object_source_combination,
|
||||
# log_messages=[],
|
||||
# ),
|
||||
# )
|
||||
# for object_source_combination in object_source_combinations
|
||||
# ]
|
||||
|
||||
|
||||
def parallel_object_research_consolidation_edge(
|
||||
state: ObjectResearchInformationUpdate, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the research for an individual object and source
|
||||
"""
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
object_research_information_results = state.object_research_information_results
|
||||
# def parallel_object_research_consolidation_edge(
|
||||
# state: ObjectResearchInformationUpdate, config: RunnableConfig
|
||||
# ) -> list[Send | Hashable]:
|
||||
# """
|
||||
# LangGraph edge to parallelize the research for an individual object and source
|
||||
# """
|
||||
# cast(GraphConfig, config["metadata"]["config"])
|
||||
# object_research_information_results = state.object_research_information_results
|
||||
|
||||
return [
|
||||
Send(
|
||||
"consolidate_object_research",
|
||||
ObjectInformationInput(
|
||||
object_information=object_information,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for object_information in object_research_information_results
|
||||
]
|
||||
# return [
|
||||
# Send(
|
||||
# "consolidate_object_research",
|
||||
# ObjectInformationInput(
|
||||
# object_information=object_information,
|
||||
# log_messages=[],
|
||||
# ),
|
||||
# )
|
||||
# for object_information in object_research_information_results
|
||||
# ]
|
||||
|
||||
@@ -1,103 +1,103 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dc_search_analysis.edges import (
    parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
    parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
    search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
    research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
    structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
    consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
    consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.edges import (
#     parallel_object_research_consolidation_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.edges import (
#     parallel_object_source_research_edge,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
#     search_objects,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
#     research_object_source,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
#     structure_research_by_object,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
#     consolidate_object_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
#     consolidate_research,
# )
# from onyx.agents.agent_search.dc_search_analysis.states import MainInput
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()

test_mode = False
# test_mode = False


def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the divide-and-conquer (DivCon) search process.
    """
# def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
#     """
#     LangGraph graph builder for the divide-and-conquer (DivCon) search process.
#     """

    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )
#     graph = StateGraph(
#         state_schema=MainState,
#         input=MainInput,
#     )

    ### Add nodes ###
#     ### Add nodes ###

    graph.add_node(
        "search_objects",
        search_objects,
    )
#     graph.add_node(
#         "search_objects",
#         search_objects,
#     )

    graph.add_node(
        "structure_research_by_source",
        structure_research_by_object,
    )
#     graph.add_node(
#         "structure_research_by_source",
#         structure_research_by_object,
#     )

    graph.add_node(
        "research_object_source",
        research_object_source,
    )
#     graph.add_node(
#         "research_object_source",
#         research_object_source,
#     )

    graph.add_node(
        "consolidate_object_research",
        consolidate_object_research,
    )
#     graph.add_node(
#         "consolidate_object_research",
#         consolidate_object_research,
#     )

    graph.add_node(
        "consolidate_research",
        consolidate_research,
    )
#     graph.add_node(
#         "consolidate_research",
#         consolidate_research,
#     )

    ### Add edges ###
#     ### Add edges ###

    graph.add_edge(start_key=START, end_key="search_objects")
#     graph.add_edge(start_key=START, end_key="search_objects")

    graph.add_conditional_edges(
        source="search_objects",
        path=parallel_object_source_research_edge,
        path_map=["research_object_source"],
    )
#     graph.add_conditional_edges(
#         source="search_objects",
#         path=parallel_object_source_research_edge,
#         path_map=["research_object_source"],
#     )

    graph.add_edge(
        start_key="research_object_source",
        end_key="structure_research_by_source",
    )
#     graph.add_edge(
#         start_key="research_object_source",
#         end_key="structure_research_by_source",
#     )

    graph.add_conditional_edges(
        source="structure_research_by_source",
        path=parallel_object_research_consolidation_edge,
        path_map=["consolidate_object_research"],
    )
#     graph.add_conditional_edges(
#         source="structure_research_by_source",
#         path=parallel_object_research_consolidation_edge,
#         path_map=["consolidate_object_research"],
#     )

    graph.add_edge(
        start_key="consolidate_object_research",
        end_key="consolidate_research",
    )
#     graph.add_edge(
#         start_key="consolidate_object_research",
#         end_key="consolidate_research",
#     )

    graph.add_edge(
        start_key="consolidate_research",
        end_key=END,
    )
#     graph.add_edge(
#         start_key="consolidate_research",
#         end_key=END,
#     )

    return graph
#     return graph
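For orientation, the wiring above yields a linear pipeline with two Send-based fan-out stages. A usage sketch (compile() is the standard LangGraph step; the flow comment only summarizes the add_edge/add_conditional_edges calls above):

# START -> search_objects
#       -> (Send: one branch per object/source pair) research_object_source
#       -> structure_research_by_source
#       -> (Send: one branch per object) consolidate_object_research
#       -> consolidate_research -> END
compiled_graph = divide_and_conquer_graph_builder().compile()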
@@ -1,146 +1,146 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
    SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     SearchSourcesObjectsUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def search_objects(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
    """
    LangGraph node to start the agentic search process.
    """
# def search_objects(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> SearchSourcesObjectsUpdate:
#     """
#     LangGraph node to start the agentic search process.
#     """

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    search_tool = graph_config.tooling.search_tool
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     question = graph_config.inputs.prompt_builder.raw_user_query
#     search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""
#     try:
#         instructions = graph_config.inputs.persona.system_prompt or ""

        agent_1_instructions = extract_section(
            instructions, "Agent Step 1:", "Agent Step 2:"
        )
        if agent_1_instructions is None:
            raise ValueError("Agent 1 instructions not found")
#         agent_1_instructions = extract_section(
#             instructions, "Agent Step 1:", "Agent Step 2:"
#         )
#         if agent_1_instructions is None:
#             raise ValueError("Agent 1 instructions not found")

        agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
#         agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")

        agent_1_task = extract_section(
            agent_1_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_1_task is None:
            raise ValueError("Agent 1 task not found")
#         agent_1_task = extract_section(
#             agent_1_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_1_task is None:
#             raise ValueError("Agent 1 task not found")

        agent_1_independent_sources_str = extract_section(
            agent_1_instructions, "Independent Research Sources:", "Output Objective:"
        )
        if agent_1_independent_sources_str is None:
            raise ValueError("Agent 1 Independent Research Sources not found")
#         agent_1_independent_sources_str = extract_section(
#             agent_1_instructions, "Independent Research Sources:", "Output Objective:"
#         )
#         if agent_1_independent_sources_str is None:
#             raise ValueError("Agent 1 Independent Research Sources not found")

        document_sources = strings_to_document_sources(
            [
                x.strip().lower()
                for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
            ]
        )
#         document_sources = strings_to_document_sources(
#             [
#                 x.strip().lower()
#                 for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
#             ]
#         )

        agent_1_output_objective = extract_section(
            agent_1_instructions, "Output Objective:"
        )
        if agent_1_output_objective is None:
            raise ValueError("Agent 1 output objective not found")
#         agent_1_output_objective = extract_section(
#             agent_1_instructions, "Output Objective:"
#         )
#         if agent_1_output_objective is None:
#             raise ValueError("Agent 1 output objective not found")

    except Exception as e:
        raise ValueError(
            f"Agent 1 instructions not found or not formatted correctly: {e}"
        )
#     except Exception as e:
#         raise ValueError(
#             f"Agent 1 instructions not found or not formatted correctly: {e}"
#         )

    # Extract objects
#     # Extract objects

    if agent_1_base_data is None:
        # Retrieve chunks for objects
#     if agent_1_base_data is None:
#         # Retrieve chunks for objects

        retrieved_docs = research(question, search_tool)[:10]
#         retrieved_docs = research(question, search_tool)[:10]

        document_texts_list = []
        for doc_num, doc in enumerate(retrieved_docs):
            chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
            document_texts_list.append(chunk_text)
#         document_texts_list = []
#         for doc_num, doc in enumerate(retrieved_docs):
#             chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
#             document_texts_list.append(chunk_text)

        document_texts = "\n\n".join(document_texts_list)
#         document_texts = "\n\n".join(document_texts_list)

        dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            document_text=document_texts,
            objects_of_interest=agent_1_output_objective,
        )
    else:
        dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
            question=question,
            task=agent_1_task,
            base_data=agent_1_base_data,
            objects_of_interest=agent_1_output_objective,
        )
#         dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
#             question=question,
#             task=agent_1_task,
#             document_text=document_texts,
#             objects_of_interest=agent_1_output_objective,
#         )
#     else:
#         dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
#             question=question,
#             task=agent_1_task,
#             base_data=agent_1_base_data,
#             objects_of_interest=agent_1_output_objective,
#         )

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_extraction_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_extraction_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        cleaned_response = cleaned_response.split("OBJECTS:")[1]
        object_list = [x.strip() for x in cleaned_response.split(";")]
#         cleaned_response = (
#             str(llm_response.content)
#             .replace("```json\n", "")
#             .replace("\n```", "")
#             .replace("\n", "")
#         )
#         cleaned_response = cleaned_response.split("OBJECTS:")[1]
#         object_list = [x.strip() for x in cleaned_response.split(";")]

    except Exception as e:
        raise ValueError(f"Error in search_objects: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in search_objects: {e}")

    return SearchSourcesObjectsUpdate(
        analysis_objects=object_list,
        analysis_sources=document_sources,
        log_messages=["Agent 1 Task done"],
    )
#     return SearchSourcesObjectsUpdate(
#         analysis_objects=object_list,
#         analysis_sources=document_sources,
#         log_messages=["Agent 1 Task done"],
#     )
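The extract_section calls above imply that the persona's system prompt carries labeled sections. An illustrative skeleton (the marker strings are the ones the code searches for; the values are invented for illustration):

Agent Step 1:
Task: Identify the accounts discussed in the base data.
Independent Research Sources: web; salesforce
Output Objective: A semicolon-separated list of account names.
Agent Step 2:
...
|Start Data|
<optional base data; when present, retrieval is skipped in this node>
|End Data|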
@@ -1,180 +1,180 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
# from datetime import datetime
# from datetime import timedelta
# from datetime import timezone
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
    ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.ops import research
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     ObjectSourceResearchUpdate,
# )
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def research_object_source(
    state: ObjectSourceInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
    """
    LangGraph node to research an individual object within a single document source.
    """
    datetime.now()
# def research_object_source(
#     state: ObjectSourceInput,
#     config: RunnableConfig,
#     writer: StreamWriter = lambda _: None,
# ) -> ObjectSourceResearchUpdate:
#     """
#     LangGraph node to research an individual object within a single document source.
#     """
#     datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
    object, document_source = state.object_source_combination
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     search_tool = graph_config.tooling.search_tool
#     question = graph_config.inputs.prompt_builder.raw_user_query
#     object, document_source = state.object_source_combination

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    try:
        instructions = graph_config.inputs.persona.system_prompt or ""
#     try:
#         instructions = graph_config.inputs.persona.system_prompt or ""

        agent_2_instructions = extract_section(
            instructions, "Agent Step 2:", "Agent Step 3:"
        )
        if agent_2_instructions is None:
            raise ValueError("Agent 2 instructions not found")
#         agent_2_instructions = extract_section(
#             instructions, "Agent Step 2:", "Agent Step 3:"
#         )
#         if agent_2_instructions is None:
#             raise ValueError("Agent 2 instructions not found")

        agent_2_task = extract_section(
            agent_2_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_2_task is None:
            raise ValueError("Agent 2 task not found")
#         agent_2_task = extract_section(
#             agent_2_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_2_task is None:
#             raise ValueError("Agent 2 task not found")

        agent_2_time_cutoff = extract_section(
            agent_2_instructions, "Time Cutoff:", "Research Topics:"
        )
#         agent_2_time_cutoff = extract_section(
#             agent_2_instructions, "Time Cutoff:", "Research Topics:"
#         )

        agent_2_research_topics = extract_section(
            agent_2_instructions, "Research Topics:", "Output Objective"
        )
#         agent_2_research_topics = extract_section(
#             agent_2_instructions, "Research Topics:", "Output Objective"
#         )

        agent_2_output_objective = extract_section(
            agent_2_instructions, "Output Objective:"
        )
        if agent_2_output_objective is None:
            raise ValueError("Agent 2 output objective not found")
#         agent_2_output_objective = extract_section(
#             agent_2_instructions, "Output Objective:"
#         )
#         if agent_2_output_objective is None:
#             raise ValueError("Agent 2 output objective not found")

    except Exception as e:
        raise ValueError(
            f"Agent 2 instructions not found or not formatted correctly: {e}"
        )
#     except Exception as e:
#         raise ValueError(
#             f"Agent 2 instructions not found or not formatted correctly: {e}"
#         )

    # Populate prompt
#     # Populate prompt

    # Retrieve chunks for objects
#     # Retrieve chunks for objects

    if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
        if agent_2_time_cutoff.strip().endswith("d"):
            try:
                days = int(agent_2_time_cutoff.strip()[:-1])
                agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
                    days=days
                )
            except ValueError:
                raise ValueError(
                    f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
                )
        else:
            raise ValueError(
                f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
            )
    else:
        agent_2_source_start_time = None
#     if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
#         if agent_2_time_cutoff.strip().endswith("d"):
#             try:
#                 days = int(agent_2_time_cutoff.strip()[:-1])
#                 agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
#                     days=days
#                 )
#             except ValueError:
#                 raise ValueError(
#                     f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
#                 )
#         else:
#             raise ValueError(
#                 f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
#             )
#     else:
#         agent_2_source_start_time = None

    document_sources = [document_source] if document_source else None
#     document_sources = [document_source] if document_source else None

    if len(question.strip()) > 0:
        research_area = f"{question} for {object}"
    elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
        research_area = f"{agent_2_research_topics} for {object}"
    else:
        research_area = object
#     if len(question.strip()) > 0:
#         research_area = f"{question} for {object}"
#     elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
#         research_area = f"{agent_2_research_topics} for {object}"
#     else:
#         research_area = object

    retrieved_docs = research(
        question=research_area,
        search_tool=search_tool,
        document_sources=document_sources,
        time_cutoff=agent_2_source_start_time,
    )
#     retrieved_docs = research(
#         question=research_area,
#         search_tool=search_tool,
#         document_sources=document_sources,
#         time_cutoff=agent_2_source_start_time,
#     )

    # Generate document text
#     # Generate document text

    document_texts_list = []
    for doc_num, doc in enumerate(retrieved_docs):
        chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
        document_texts_list.append(chunk_text)
#     document_texts_list = []
#     for doc_num, doc in enumerate(retrieved_docs):
#         chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
#         document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)
#     document_texts = "\n\n".join(document_texts_list)

    # Build prompt
#     # Build prompt

    today = datetime.now().strftime("%A, %Y-%m-%d")
#     today = datetime.now().strftime("%A, %Y-%m-%d")

    dc_object_source_research_prompt = (
        DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
            today=today,
            question=question,
            task=agent_2_task,
            document_text=document_texts,
            format=agent_2_output_objective,
        )
        .replace("---object---", object)
        .replace("---source---", document_source.value)
    )
#     dc_object_source_research_prompt = (
#         DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
#             today=today,
#             question=question,
#             task=agent_2_task,
#             document_text=document_texts,
#             format=agent_2_output_objective,
#         )
#         .replace("---object---", object)
#         .replace("---source---", document_source.value)
#     )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_source_research_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
        object_research_results = {
            "object": object,
            "source": document_source.value,
            "research_result": cleaned_response,
        }
#         cleaned_response = str(llm_response.content).replace("```json\n", "")
#         cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
#         object_research_results = {
#             "object": object,
#             "source": document_source.value,
#             "research_result": cleaned_response,
#         }

    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in research_object_source: {e}")

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
#     logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ObjectSourceResearchUpdate(
        object_source_research_results=[object_research_results],
        log_messages=["Agent Step 2 done for one object"],
    )
#     return ObjectSourceResearchUpdate(
#         object_source_research_results=[object_research_results],
#         log_messages=["Agent Step 2 done for one object"],
#     )
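The time-cutoff handling above accepts only values of the form "<number>d" (days back from now). A condensed sketch of the same parsing as a standalone helper (the helper name is invented; the behavior mirrors the branch logic above):

from datetime import datetime, timedelta, timezone


def parse_day_cutoff(cutoff: str | None) -> datetime | None:
    # "30d" -> UTC timestamp 30 days ago; None or blank -> no cutoff
    if cutoff is None or not cutoff.strip():
        return None
    cutoff = cutoff.strip()
    if not cutoff.endswith("d") or not cutoff[:-1].isdigit():
        raise ValueError(
            f"Invalid time cutoff format: {cutoff}. Expected format: '<number>d'"
        )
    return datetime.now(timezone.utc) - timedelta(days=int(cutoff[:-1]))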
@@ -1,48 +1,48 @@
from collections import defaultdict
from typing import Dict
from typing import List
# from collections import defaultdict
# from typing import Dict
# from typing import List

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
    ObjectResearchInformationUpdate,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import (
#     ObjectResearchInformationUpdate,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def structure_research_by_object(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
    """
    LangGraph node to structure the per-source research results by object.
    """
# def structure_research_by_object(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ObjectResearchInformationUpdate:
#     """
#     LangGraph node to structure the per-source research results by object.
#     """

    object_source_research_results = state.object_source_research_results
#     object_source_research_results = state.object_source_research_results

    object_research_information_results: List[Dict[str, str]] = []
    object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
#     object_research_information_results: List[Dict[str, str]] = []
#     object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)

    for object_source_research in object_source_research_results:
        object = object_source_research["object"]
        source = object_source_research["source"]
        research_result = object_source_research["research_result"]
#     for object_source_research in object_source_research_results:
#         object = object_source_research["object"]
#         source = object_source_research["source"]
#         research_result = object_source_research["research_result"]

        object_research_information_results_list[object].append(
            f"Source: {source}\n{research_result}"
        )
#         object_research_information_results_list[object].append(
#             f"Source: {source}\n{research_result}"
#         )

    for object, information in object_research_information_results_list.items():
        object_research_information_results.append(
            {"object": object, "information": "\n".join(information)}
        )
#     for object, information in object_research_information_results_list.items():
#         object_research_information_results.append(
#             {"object": object, "information": "\n".join(information)}
#         )

    logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
#     logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")

    return ObjectResearchInformationUpdate(
        object_research_information_results=object_research_information_results,
        log_messages=["A3 - Object Research Information structured"],
    )
#     return ObjectResearchInformationUpdate(
#         object_research_information_results=object_research_information_results,
#         log_messages=["A3 - Object Research Information structured"],
#     )
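The grouping above is a plain dictionary fold; a condensed, runnable illustration with toy rows:

from collections import defaultdict

rows = [
    {"object": "ACME", "source": "web", "research_result": "grew 12% YoY"},
    {"object": "ACME", "source": "salesforce", "research_result": "renewal in Q3"},
]
grouped: dict[str, list[str]] = defaultdict(list)
for row in rows:
    grouped[row["object"]].append(f"Source: {row['source']}\n{row['research_result']}")
merged = [
    {"object": obj, "information": "\n".join(info)} for obj, info in grouped.items()
]
# merged holds one dict per object, with all per-source findings joined together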
@@ -1,103 +1,103 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
# from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def consolidate_object_research(
    state: ObjectInformationInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
    """
    LangGraph node to consolidate the research for an individual object.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
# def consolidate_object_research(
#     state: ObjectInformationInput,
#     config: RunnableConfig,
#     writer: StreamWriter = lambda _: None,
# ) -> ObjectResearchUpdate:
#     """
#     LangGraph node to consolidate the research for an individual object.
#     """
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     search_tool = graph_config.tooling.search_tool
#     question = graph_config.inputs.prompt_builder.raw_user_query

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    instructions = graph_config.inputs.persona.system_prompt or ""
#     instructions = graph_config.inputs.persona.system_prompt or ""

    agent_4_instructions = extract_section(
        instructions, "Agent Step 4:", "Agent Step 5:"
    )
    if agent_4_instructions is None:
        raise ValueError("Agent 4 instructions not found")
    agent_4_output_objective = extract_section(
        agent_4_instructions, "Output Objective:"
    )
    if agent_4_output_objective is None:
        raise ValueError("Agent 4 output objective not found")
#     agent_4_instructions = extract_section(
#         instructions, "Agent Step 4:", "Agent Step 5:"
#     )
#     if agent_4_instructions is None:
#         raise ValueError("Agent 4 instructions not found")
#     agent_4_output_objective = extract_section(
#         agent_4_instructions, "Output Objective:"
#     )
#     if agent_4_output_objective is None:
#         raise ValueError("Agent 4 output objective not found")

    object_information = state.object_information
#     object_information = state.object_information

    object = object_information["object"]
    information = object_information["information"]
#     object = object_information["object"]
#     information = object_information["information"]

    # Create a prompt for the object consolidation
#     # Create a prompt for the object consolidation

    dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
        question=question,
        object=object,
        information=information,
        format=agent_4_output_objective,
    )
#     dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
#         question=question,
#         object=object,
#         information=information,
#         format=agent_4_output_objective,
#     )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_object_consolidation_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke_langchain,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_object_consolidation_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]
#     primary_llm = graph_config.tooling.primary_llm
#     # Grader
#     try:
#         llm_response = run_with_timeout(
#             30,
#             primary_llm.invoke_langchain,
#             prompt=msg,
#             timeout_override=30,
#             max_tokens=300,
#         )

        cleaned_response = str(llm_response.content).replace("```json\n", "")
        consolidated_information = cleaned_response.split("INFORMATION:")[1]
#         cleaned_response = str(llm_response.content).replace("```json\n", "")
#         consolidated_information = cleaned_response.split("INFORMATION:")[1]

    except Exception as e:
        raise ValueError(f"Error in consolidate_object_research: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in consolidate_object_research: {e}")

    object_research_results = {
        "object": object,
        "research_result": consolidated_information,
    }
#     object_research_results = {
#         "object": object,
#         "research_result": consolidated_information,
#     }

    logger.debug(
        "DivCon Step A4 - Object Research Consolidation - completed for an object"
    )
#     logger.debug(
#         "DivCon Step A4 - Object Research Consolidation - completed for an object"
#     )

    return ObjectResearchUpdate(
        object_research_results=[object_research_results],
        log_messages=["Agent Source Consolidation done"],
    )
#     return ObjectResearchUpdate(
#         object_research_results=[object_research_results],
#         log_messages=["Agent Source Consolidation done"],
#     )
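These nodes share one response-parsing convention: the LLM is prompted to emit a labeled section, and everything after the label is kept. A minimal illustration of the split used above, with a toy response string:

raw = "```json\nINFORMATION: ACME: strong pipeline, renewal likely in Q3."
cleaned = raw.replace("```json\n", "")
consolidated = cleaned.split("INFORMATION:")[1].strip()
# consolidated == "ACME: strong pipeline, renewal likely in Q3."

If the label is missing from the response, the [1] index raises an IndexError, which the surrounding try/except re-raises as a ValueError.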
@@ -1,127 +1,127 @@
from typing import cast
# from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import HumanMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
# from onyx.agents.agent_search.dc_search_analysis.states import MainState
# from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
#     trim_prompt_piece,
# )
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
# from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()


def consolidate_research(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
    """
    LangGraph node to consolidate the research results and stream the final answer.
    """
# def consolidate_research(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> ResearchUpdate:
#     """
#     LangGraph node to consolidate the research results and stream the final answer.
#     """

    graph_config = cast(GraphConfig, config["metadata"]["config"])
#     graph_config = cast(GraphConfig, config["metadata"]["config"])

    search_tool = graph_config.tooling.search_tool
#     search_tool = graph_config.tooling.search_tool

    if search_tool is None or graph_config.inputs.persona is None:
        raise ValueError("Search tool and persona must be provided for DivCon search")
#     if search_tool is None or graph_config.inputs.persona is None:
#         raise ValueError("Search tool and persona must be provided for DivCon search")

    # Populate prompt
    instructions = graph_config.inputs.persona.system_prompt or ""
#     # Populate prompt
#     instructions = graph_config.inputs.persona.system_prompt or ""

    try:
        agent_5_instructions = extract_section(
            instructions, "Agent Step 5:", "Agent End"
        )
        if agent_5_instructions is None:
            raise ValueError("Agent 5 instructions not found")
        agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
        agent_5_task = extract_section(
            agent_5_instructions, "Task:", "Independent Research Sources:"
        )
        if agent_5_task is None:
            raise ValueError("Agent 5 task not found")
        agent_5_output_objective = extract_section(
            agent_5_instructions, "Output Objective:"
        )
        if agent_5_output_objective is None:
            raise ValueError("Agent 5 output objective not found")
    except ValueError as e:
        raise ValueError(
            f"Instructions for Agent Step 5 were not properly formatted: {e}"
        )
#     try:
#         agent_5_instructions = extract_section(
#             instructions, "Agent Step 5:", "Agent End"
#         )
#         if agent_5_instructions is None:
#             raise ValueError("Agent 5 instructions not found")
#         agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
#         agent_5_task = extract_section(
#             agent_5_instructions, "Task:", "Independent Research Sources:"
#         )
#         if agent_5_task is None:
#             raise ValueError("Agent 5 task not found")
#         agent_5_output_objective = extract_section(
#             agent_5_instructions, "Output Objective:"
#         )
#         if agent_5_output_objective is None:
#             raise ValueError("Agent 5 output objective not found")
#     except ValueError as e:
#         raise ValueError(
#             f"Instructions for Agent Step 5 were not properly formatted: {e}"
#         )

    research_result_list = []
#     research_result_list = []

    if agent_5_task.strip() == "*concatenate*":
        object_research_results = state.object_research_results
#     if agent_5_task.strip() == "*concatenate*":
#         object_research_results = state.object_research_results

        for object_research_result in object_research_results:
            object = object_research_result["object"]
            research_result = object_research_result["research_result"]
            research_result_list.append(f"Object: {object}\n\n{research_result}")
#         for object_research_result in object_research_results:
#             object = object_research_result["object"]
#             research_result = object_research_result["research_result"]
#             research_result_list.append(f"Object: {object}\n\n{research_result}")

        research_results = "\n\n".join(research_result_list)
#         research_results = "\n\n".join(research_result_list)

    else:
        raise NotImplementedError("Only '*concatenate*' is currently supported")
#     else:
#         raise NotImplementedError("Only '*concatenate*' is currently supported")

    # Create a prompt for the object consolidation
#     # Create a prompt for the object consolidation

    if agent_5_base_data is None:
        dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
            text=research_results,
            format=agent_5_output_objective,
        )
    else:
        dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
            base_data=agent_5_base_data,
            text=research_results,
            format=agent_5_output_objective,
        )
#     if agent_5_base_data is None:
#         dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
#             text=research_results,
#             format=agent_5_output_objective,
#         )
#     else:
#         dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
#             base_data=agent_5_base_data,
#             text=research_results,
#             format=agent_5_output_objective,
#         )

    # Run LLM
#     # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=dc_formatting_prompt,
                reserved_str="",
            ),
        )
    ]
#     msg = [
#         HumanMessage(
#             content=trim_prompt_piece(
#                 config=graph_config.tooling.primary_llm.config,
#                 prompt_piece=dc_formatting_prompt,
#                 reserved_str="",
#             ),
#         )
#     ]

    try:
        _ = run_with_timeout(
            60,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=30,
                max_tokens=None,
            ),
        )
#     try:
#         _ = run_with_timeout(
#             60,
#             lambda: stream_llm_answer(
#                 llm=graph_config.tooling.primary_llm,
#                 prompt=msg,
#                 event_name="initial_agent_answer",
#                 writer=writer,
#                 agent_answer_level=0,
#                 agent_answer_question_num=0,
#                 agent_answer_type="agent_level_answer",
#                 timeout_override=30,
#                 max_tokens=None,
#             ),
#         )

    except Exception as e:
        raise ValueError(f"Error in consolidate_research: {e}")
#     except Exception as e:
#         raise ValueError(f"Error in consolidate_research: {e}")

    logger.debug("DivCon Step A5 - Final Generation - completed")
#     logger.debug("DivCon Step A5 - Final Generation - completed")

    return ResearchUpdate(
        research_results=research_results,
        log_messages=["Agent Source Consolidation done"],
    )
#     return ResearchUpdate(
#         research_results=research_results,
#         log_messages=["Agent Source Consolidation done"],
#     )
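run_with_timeout appears throughout with the timeout as the first positional argument, then the callable and its kwargs. A minimal stand-in with that calling convention (an assumption inferred from the call sites; the real helper lives in onyx.utils.threadpool_concurrency and is not reproduced here):

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, TypeVar

R = TypeVar("R")


def run_with_timeout_sketch(
    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
    # Run func in a worker thread; raise TimeoutError if it exceeds `timeout`
    # seconds (note: the worker itself is not forcibly cancelled).
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(func, *args, **kwargs).result(timeout=timeout)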
@@ -1,61 +1,50 @@
from datetime import datetime
from typing import cast
# from datetime import datetime
# from typing import cast

from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.chat.models import LlmDoc
# from onyx.configs.constants import DocumentSource
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import (
#     FINAL_CONTEXT_DOCUMENTS_ID,
# )
# from onyx.tools.tool_implementations.search.search_tool import SearchTool


def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
    # new db session to avoid concurrency issues
# def research(
#     question: str,
#     search_tool: SearchTool,
#     document_sources: list[DocumentSource] | None = None,
#     time_cutoff: datetime | None = None,
# ) -> list[LlmDoc]:
#     # new db session to avoid concurrency issues

    callback_container: list[list[InferenceSection]] = []
    retrieved_docs: list[LlmDoc] = []
#     retrieved_docs: list[LlmDoc] = []

    with get_session_with_current_tenant() as db_session:
        for tool_response in search_tool.run(
            query=question,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=False,
                alternate_db_session=db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                document_sources=document_sources,
                time_cutoff=time_cutoff,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
                break
    return retrieved_docs
#     for tool_response in search_tool.run(
#         query=question,
#         override_kwargs=SearchToolOverrideKwargs(original_query=question),
#     ):
#         # get retrieved docs to send to the rest of the graph
#         if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
#             retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
#             break
#     return retrieved_docs


def extract_section(
    text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
    """Extract text between markers, returning None if markers not found"""
    parts = text.split(start_marker)
# def extract_section(
#     text: str, start_marker: str, end_marker: str | None = None
# ) -> str | None:
#     """Extract text between markers, returning None if markers not found"""
#     parts = text.split(start_marker)

    if len(parts) == 1:
        return None
#     if len(parts) == 1:
#         return None

    after_start = parts[1].strip()
#     after_start = parts[1].strip()

    if not end_marker:
        return after_start
#     if not end_marker:
#         return after_start

    extract = after_start.split(end_marker)[0]
#     extract = after_start.split(end_marker)[0]

    return extract.strip()
#     return extract.strip()
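extract_section is the marker-based parser all the agent steps depend on. Its behavior on a toy string, for reference:

text = "Task: find owners Independent Research Sources: web Output Objective: a list"

extract_section(text, "Task:", "Independent Research Sources:")  # -> "find owners"
extract_section(text, "Output Objective:")                       # -> "a list"
extract_section(text, "Missing Marker:")                         # -> None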
@@ -1,72 +1,72 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Dict
# from typing import TypedDict

from pydantic import BaseModel
# from pydantic import BaseModel

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
# from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
# from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# from onyx.configs.constants import DocumentSource


### States ###
class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []
# ### States ###
# class LoggerUpdate(BaseModel):
#     log_messages: Annotated[list[str], add] = []


class SearchSourcesObjectsUpdate(LoggerUpdate):
    analysis_objects: list[str] = []
    analysis_sources: list[DocumentSource] = []
# class SearchSourcesObjectsUpdate(LoggerUpdate):
#     analysis_objects: list[str] = []
#     analysis_sources: list[DocumentSource] = []


class ObjectSourceInput(LoggerUpdate):
    object_source_combination: tuple[str, DocumentSource]
# class ObjectSourceInput(LoggerUpdate):
#     object_source_combination: tuple[str, DocumentSource]


class ObjectSourceResearchUpdate(LoggerUpdate):
    object_source_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectSourceResearchUpdate(LoggerUpdate):
#     object_source_research_results: Annotated[list[Dict[str, str]], add] = []


class ObjectInformationInput(LoggerUpdate):
    object_information: Dict[str, str]
# class ObjectInformationInput(LoggerUpdate):
#     object_information: Dict[str, str]


class ObjectResearchInformationUpdate(LoggerUpdate):
    object_research_information_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchInformationUpdate(LoggerUpdate):
#     object_research_information_results: Annotated[list[Dict[str, str]], add] = []


class ObjectResearchUpdate(LoggerUpdate):
    object_research_results: Annotated[list[Dict[str, str]], add] = []
# class ObjectResearchUpdate(LoggerUpdate):
#     object_research_results: Annotated[list[Dict[str, str]], add] = []


class ResearchUpdate(LoggerUpdate):
    research_results: str | None = None
# class ResearchUpdate(LoggerUpdate):
#     research_results: str | None = None


## Graph Input State
class MainInput(CoreState):
    pass
# ## Graph Input State
# class MainInput(CoreState):
#     pass


## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    SearchSourcesObjectsUpdate,
    ObjectSourceResearchUpdate,
    ObjectResearchInformationUpdate,
    ObjectResearchUpdate,
    ResearchUpdate,
):
    pass
# ## Graph State
# class MainState(
#     # This includes the core state
#     MainInput,
#     ToolChoiceInput,
#     ToolCallUpdate,
#     ToolChoiceUpdate,
#     SearchSourcesObjectsUpdate,
#     ObjectSourceResearchUpdate,
#     ObjectResearchInformationUpdate,
#     ObjectResearchUpdate,
#     ResearchUpdate,
# ):
#     pass


## Graph Output State - presently not used
class MainOutput(TypedDict):
    log_messages: list[str]
# ## Graph Output State - presently not used
# class MainOutput(TypedDict):
#     log_messages: list[str]
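The Annotated[..., add] fields above act as LangGraph reducers: when parallel Send branches each return an update for the same key, the runtime combines the updates with the annotated function instead of overwriting them. Since operator.add concatenates lists, per-branch results accumulate:

from operator import add

branch_1 = [{"object": "ACME", "research_result": "..."}]
branch_2 = [{"object": "Globex", "research_result": "..."}]
add(branch_1, branch_2)  # -> both dicts, in order; this is how parallel updates merge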
@@ -1,36 +1,36 @@
from pydantic import BaseModel
# from pydantic import BaseModel


class RefinementSubQuestion(BaseModel):
    sub_question: str
    sub_question_id: str
    verified: bool
    answered: bool
    answer: str
# class RefinementSubQuestion(BaseModel):
#     sub_question: str
#     sub_question_id: str
#     verified: bool
#     answered: bool
#     answer: str


class AgentTimings(BaseModel):
    base_duration_s: float | None
    refined_duration_s: float | None
    full_duration_s: float | None
# class AgentTimings(BaseModel):
#     base_duration_s: float | None
#     refined_duration_s: float | None
#     full_duration_s: float | None


class AgentBaseMetrics(BaseModel):
    num_verified_documents_total: int | None
    num_verified_documents_core: int | None
    verified_avg_score_core: float | None
    num_verified_documents_base: int | float | None
    verified_avg_score_base: float | None = None
    base_doc_boost_factor: float | None = None
    support_boost_factor: float | None = None
    duration_s: float | None = None
# class AgentBaseMetrics(BaseModel):
#     num_verified_documents_total: int | None
#     num_verified_documents_core: int | None
#     verified_avg_score_core: float | None
#     num_verified_documents_base: int | float | None
#     verified_avg_score_base: float | None = None
#     base_doc_boost_factor: float | None = None
#     support_boost_factor: float | None = None
#     duration_s: float | None = None


class AgentRefinedMetrics(BaseModel):
    refined_doc_boost_factor: float | None = None
    refined_question_boost_factor: float | None = None
    duration_s: float | None = None
# class AgentRefinedMetrics(BaseModel):
#     refined_doc_boost_factor: float | None = None
#     refined_question_boost_factor: float | None = None
#     duration_s: float | None = None


class AgentAdditionalMetrics(BaseModel):
    pass
# class AgentAdditionalMetrics(BaseModel):
#     pass
@@ -1,61 +1,61 @@
from collections.abc import Hashable
# from collections.abc import Hashable

from langgraph.graph import END
from langgraph.types import Send
# from langgraph.graph import END
# from langgraph.types import Send

from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.states import MainState


def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
    if not state.tools_used:
        raise IndexError("state.tools_used cannot be empty")
# def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
#     if not state.tools_used:
#         raise IndexError("state.tools_used cannot be empty")

    # next_tool is either a generic tool name or a DRPath string
    next_tool_name = state.tools_used[-1]
#     # next_tool is either a generic tool name or a DRPath string
#     next_tool_name = state.tools_used[-1]

    available_tools = state.available_tools
    if not available_tools:
        raise ValueError("No tool is available. This should not happen.")
#     available_tools = state.available_tools
#     if not available_tools:
#         raise ValueError("No tool is available. This should not happen.")

    if next_tool_name in available_tools:
        next_tool_path = available_tools[next_tool_name].path
    elif next_tool_name == DRPath.END.value:
        return END
    elif next_tool_name == DRPath.LOGGER.value:
        return DRPath.LOGGER
    elif next_tool_name == DRPath.CLOSER.value:
        return DRPath.CLOSER
    else:
        return DRPath.ORCHESTRATOR
#     if next_tool_name in available_tools:
#         next_tool_path = available_tools[next_tool_name].path
#     elif next_tool_name == DRPath.END.value:
#         return END
#     elif next_tool_name == DRPath.LOGGER.value:
#         return DRPath.LOGGER
#     elif next_tool_name == DRPath.CLOSER.value:
#         return DRPath.CLOSER
#     else:
#         return DRPath.ORCHESTRATOR

    # handle invalid paths
    if next_tool_path == DRPath.CLARIFIER:
        raise ValueError("CLARIFIER is not a valid path during iteration")
#     # handle invalid paths
#     if next_tool_path == DRPath.CLARIFIER:
#         raise ValueError("CLARIFIER is not a valid path during iteration")

    # handle tool calls without a query
    if (
        next_tool_path
        in (
            DRPath.INTERNAL_SEARCH,
            DRPath.WEB_SEARCH,
            DRPath.KNOWLEDGE_GRAPH,
            DRPath.IMAGE_GENERATION,
        )
        and len(state.query_list) == 0
    ):
        return DRPath.CLOSER
#     # handle tool calls without a query
#     if (
#         next_tool_path
#         in (
#             DRPath.INTERNAL_SEARCH,
#             DRPath.WEB_SEARCH,
#             DRPath.KNOWLEDGE_GRAPH,
#             DRPath.IMAGE_GENERATION,
#         )
#         and len(state.query_list) == 0
#     ):
#         return DRPath.CLOSER

    return next_tool_path
#     return next_tool_path


def completeness_router(state: MainState) -> DRPath | str:
    if not state.tools_used:
        raise IndexError("tools_used cannot be empty")
# def completeness_router(state: MainState) -> DRPath | str:
#     if not state.tools_used:
#         raise IndexError("tools_used cannot be empty")

    # go to closer if path is CLOSER or no queries
    next_path = state.tools_used[-1]
#     # go to closer if path is CLOSER or no queries
#     next_path = state.tools_used[-1]

    if next_path == DRPath.ORCHESTRATOR.value:
        return DRPath.ORCHESTRATOR
    return DRPath.LOGGER
#     if next_path == DRPath.ORCHESTRATOR.value:
#         return DRPath.ORCHESTRATOR
#     return DRPath.LOGGER
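For context, routers like the two above are consumed by add_conditional_edges: LangGraph calls the router after the source node runs and follows whatever node name (or END) it returns. A minimal runnable sketch with a toy state and toy nodes (assumption: illustrative only, not the repo's real graph):

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class ToyState(TypedDict):
    tools_used: list[str]


def orchestrator(state: ToyState) -> ToyState:
    return {"tools_used": state["tools_used"] + ["closer"]}


def closer(state: ToyState) -> ToyState:
    return state


def router(state: ToyState) -> str:
    # Mirrors the shape of decision_router: dispatch on the last tool used.
    return "closer" if state["tools_used"][-1] == "closer" else END


builder = StateGraph(ToyState)
builder.add_node("orchestrator", orchestrator)
builder.add_node("closer", closer)
builder.add_edge(START, "orchestrator")
builder.add_conditional_edges("orchestrator", router)
builder.add_edge("closer", END)
print(builder.compile().invoke({"tools_used": ["orchestrator"]}))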
@@ -1,31 +1,27 @@
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
# from onyx.agents.agent_search.dr.enums import DRPath

MAX_CHAT_HISTORY_MESSAGES = (
    3  # note: actual count is x2 to account for user and assistant messages
)
# MAX_CHAT_HISTORY_MESSAGES = (
#     3  # note: actual count is x2 to account for user and assistant messages
# )

MAX_DR_PARALLEL_SEARCH = 4
# MAX_DR_PARALLEL_SEARCH = 4

# TODO: test more, generally not needed/adds unnecessary iterations
MAX_NUM_CLOSER_SUGGESTIONS = (
    0  # how many times the closer can send back to the orchestrator
)
# # TODO: test more, generally not needed/adds unnecessary iterations
# MAX_NUM_CLOSER_SUGGESTIONS = (
#     0  # how many times the closer can send back to the orchestrator
# )

CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
# CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
# HIGH_LEVEL_PLAN_PREFIX = "The Plan:"

AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
    DRPath.INTERNAL_SEARCH: 1.0,
    DRPath.KNOWLEDGE_GRAPH: 2.0,
    DRPath.WEB_SEARCH: 1.5,
    DRPath.IMAGE_GENERATION: 3.0,
    DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
    DRPath.CLOSER: 0.0,
}
# AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
#     DRPath.INTERNAL_SEARCH: 1.0,
#     DRPath.KNOWLEDGE_GRAPH: 2.0,
#     DRPath.WEB_SEARCH: 1.5,
#     DRPath.IMAGE_GENERATION: 3.0,
#     DRPath.GENERIC_TOOL: 1.5,  # TODO: see todo in OrchestratorTool
#     DRPath.CLOSER: 0.0,
# }

DR_TIME_BUDGET_BY_TYPE = {
    ResearchType.THOUGHTFUL: 3.0,
    ResearchType.DEEP: 12.0,
    ResearchType.FAST: 0.5,
}
# # Default time budget for agentic search (when use_agentic_search is True)
# DR_TIME_BUDGET_DEFAULT = 12.0
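How these constants are meant to compose (an assumption; the actual accounting happens in the orchestrator node, which is not part of this hunk): the research type selects a total budget, and each tool call spends its average cost against it.

from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType

remaining_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]  # 12.0
remaining_budget -= AVERAGE_TOOL_COSTS[DRPath.KNOWLEDGE_GRAPH]  # 12.0 - 2.0 -> 10.0
# CLOSER costs 0.0, so finalizing the answer never exhausts the budget.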
@@ -1,112 +1,111 @@
from datetime import datetime
# from datetime import datetime

from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
from onyx.prompts.prompt_template import PromptTemplate
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.models import DRPromptPurpose
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
# from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
# from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
# from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
# from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
# from onyx.prompts.prompt_template import PromptTemplate


def get_dr_prompt_orchestration_templates(
    purpose: DRPromptPurpose,
    research_type: ResearchType,
    available_tools: dict[str, OrchestratorTool],
    entity_types_string: str | None = None,
    relationship_types_string: str | None = None,
    reasoning_result: str | None = None,
    tool_calls_string: str | None = None,
) -> PromptTemplate:
    available_tools = available_tools or {}
    tool_names = list(available_tools.keys())
    tool_description_str = "\n\n".join(
        f"- {tool_name}: {tool.description}"
        for tool_name, tool in available_tools.items()
    )
    tool_cost_str = "\n".join(
        f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
    )
# def get_dr_prompt_orchestration_templates(
#     purpose: DRPromptPurpose,
#     use_agentic_search: bool,
#     available_tools: dict[str, OrchestratorTool],
#     entity_types_string: str | None = None,
#     relationship_types_string: str | None = None,
#     reasoning_result: str | None = None,
#     tool_calls_string: str | None = None,
# ) -> PromptTemplate:
#     available_tools = available_tools or {}
#     tool_names = list(available_tools.keys())
#     tool_description_str = "\n\n".join(
#         f"- {tool_name}: {tool.description}"
#         for tool_name, tool in available_tools.items()
#     )
#     tool_cost_str = "\n".join(
#         f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
#     )

    tool_differentiations: list[str] = [
        TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
        for tool_1 in available_tools
        for tool_2 in available_tools
        if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
    ]
    tool_differentiation_hint_string = (
        "\n".join(tool_differentiations) or "(No differentiating hints available)"
    )
    # TODO: add tool deliniation pairs for custom tools as well
#     tool_differentiations: list[str] = [
#         TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
#         for tool_1 in available_tools
#         for tool_2 in available_tools
#         if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
#     ]
#     tool_differentiation_hint_string = (
#         "\n".join(tool_differentiations) or "(No differentiating hints available)"
#     )
#     # TODO: add tool deliniation pairs for custom tools as well

    tool_question_hint_string = (
        "\n".join(
            "- " + TOOL_QUESTION_HINTS[tool]
            for tool in available_tools
            if tool in TOOL_QUESTION_HINTS
        )
        or "(No examples available)"
    )
#     tool_question_hint_string = (
#         "\n".join(
#             "- " + TOOL_QUESTION_HINTS[tool]
#             for tool in available_tools
#             if tool in TOOL_QUESTION_HINTS
#         )
#         or "(No examples available)"
#     )

    if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
        entity_types_string or relationship_types_string
    ):
#     if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
#         entity_types_string or relationship_types_string
#     ):

        kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
            possible_entities=entity_types_string or "",
            possible_relationships=relationship_types_string or "",
        )
    else:
        kg_types_descriptions = "(The Knowledge Graph is not used.)"
#         kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
#             possible_entities=entity_types_string or "",
#             possible_relationships=relationship_types_string or "",
#         )
#     else:
#         kg_types_descriptions = "(The Knowledge Graph is not used.)"

    if purpose == DRPromptPurpose.PLAN:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError("plan generation is not supported for FAST time budget")
        base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
#     if purpose == DRPromptPurpose.PLAN:
#         if not use_agentic_search:
#             raise ValueError("plan generation is only supported for agentic search")
#         base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
        else:
            raise ValueError(
                "reasoning is not separately required for DEEP time budget"
            )
#     elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
#         if not use_agentic_search:
#             base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
#         else:
#             raise ValueError(
#                 "reasoning is not separately required for agentic search"
#             )

    elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
        base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
#     elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
#         base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT

    elif purpose == DRPromptPurpose.NEXT_STEP:
        if research_type == ResearchType.THOUGHTFUL:
            base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
        else:
            base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
#     elif purpose == DRPromptPurpose.NEXT_STEP:
#         if not use_agentic_search:
#             base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
#         else:
#             base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT

    elif purpose == DRPromptPurpose.CLARIFICATION:
        if research_type == ResearchType.THOUGHTFUL:
            raise ValueError("clarification is not supported for FAST time budget")
        base_template = GET_CLARIFICATION_PROMPT
#     elif purpose == DRPromptPurpose.CLARIFICATION:
#         if not use_agentic_search:
#             raise ValueError("clarification is only supported for agentic search")
#         base_template = GET_CLARIFICATION_PROMPT

    else:
        # for mypy, clearly a mypy bug
        raise ValueError(f"Invalid purpose: {purpose}")
#     else:
#         # for mypy, clearly a mypy bug
#         raise ValueError(f"Invalid purpose: {purpose}")

    return base_template.partial_build(
        num_available_tools=str(len(tool_names)),
        available_tools=", ".join(tool_names),
        tool_choice_options=" or ".join(tool_names),
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        kg_types_descriptions=kg_types_descriptions,
        tool_descriptions=tool_description_str,
        tool_differentiation_hints=tool_differentiation_hint_string,
        tool_question_hints=tool_question_hint_string,
        average_tool_costs=tool_cost_str,
        reasoning_result=reasoning_result or "(No reasoning result provided.)",
        tool_calls_string=tool_calls_string or "(No tool calls provided.)",
    )
#     return base_template.partial_build(
#         num_available_tools=str(len(tool_names)),
#         available_tools=", ".join(tool_names),
#         tool_choice_options=" or ".join(tool_names),
#         current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
#         kg_types_descriptions=kg_types_descriptions,
#         tool_descriptions=tool_description_str,
#         tool_differentiation_hints=tool_differentiation_hint_string,
#         tool_question_hints=tool_question_hint_string,
#         average_tool_costs=tool_cost_str,
#         reasoning_result=reasoning_result or "(No reasoning result provided.)",
#         tool_calls_string=tool_calls_string or "(No tool calls provided.)",
#     )
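The function above relies on a two-stage template contract: partial_build fills a subset of placeholders and returns another template, while build fills the remainder and returns the final string (as with KG_TYPES_DESCRIPTIONS.build above). A minimal sketch of that contract (assumption: the repo's PromptTemplate implementation differs in detail; this only illustrates the shape):

class MiniTemplate:
    def __init__(self, text: str) -> None:
        self.text = text

    def partial_build(self, **kwargs: str) -> "MiniTemplate":
        # Fill only the placeholders provided; leave the rest for a later build().
        out = self.text
        for key, value in kwargs.items():
            out = out.replace("{" + key + "}", value)
        return MiniTemplate(out)

    def build(self, **kwargs: str) -> str:
        return self.partial_build(**kwargs).text


template = MiniTemplate("Tools: {available_tools}. Question: {question}")
orchestration_template = template.partial_build(available_tools="Internal Search, Web Search")
print(orchestration_template.build(question="What changed last week?"))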
@@ -1,32 +1,22 @@
from enum import Enum
# from enum import Enum


class ResearchType(str, Enum):
    """Research type options for agent search operations"""
# class ResearchAnswerPurpose(str, Enum):
#     """Research answer purpose options for agent search operations"""

    # BASIC = "BASIC"
    LEGACY_AGENTIC = "LEGACY_AGENTIC"  # only used for legacy agentic search migrations
    THOUGHTFUL = "THOUGHTFUL"
    DEEP = "DEEP"
    FAST = "FAST"
#     ANSWER = "ANSWER"
#     CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"


class ResearchAnswerPurpose(str, Enum):
    """Research answer purpose options for agent search operations"""

    ANSWER = "ANSWER"
    CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"


class DRPath(str, Enum):
    CLARIFIER = "Clarifier"
    ORCHESTRATOR = "Orchestrator"
    INTERNAL_SEARCH = "Internal Search"
    GENERIC_TOOL = "Generic Tool"
    KNOWLEDGE_GRAPH = "Knowledge Graph Search"
    WEB_SEARCH = "Web Search"
    IMAGE_GENERATION = "Image Generation"
    GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
    CLOSER = "Closer"
    LOGGER = "Logger"
    END = "End"
# class DRPath(str, Enum):
#     CLARIFIER = "Clarifier"
#     ORCHESTRATOR = "Orchestrator"
#     INTERNAL_SEARCH = "Internal Search"
#     GENERIC_TOOL = "Generic Tool"
#     KNOWLEDGE_GRAPH = "Knowledge Graph Search"
#     WEB_SEARCH = "Web Search"
#     IMAGE_GENERATION = "Image Generation"
#     GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
#     CLOSER = "Closer"
#     LOGGER = "Logger"
#     END = "End"
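Because these enums subclass str, members compare and hash like their raw string values, which is what lets the routers return a DRPath where LangGraph expects a node-name string and lets dicts keyed by DRPath be probed with plain strings. An illustrative check:

from enum import Enum


class DRPath(str, Enum):
    CLOSER = "Closer"
    END = "End"


assert DRPath.CLOSER == "Closer"  # str comparison works directly
assert {DRPath.CLOSER: 0.0}["Closer"] == 0.0  # enum and str hash to the same bucket
assert DRPath.END.value == "End"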
@@ -1,88 +1,88 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.conditional_edges import completeness_router
from onyx.agents.agent_search.dr.conditional_edges import decision_router
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
from onyx.agents.agent_search.dr.states import MainInput
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
    dr_basic_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
    dr_custom_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
    dr_generic_internal_tool_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
    dr_image_generation_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
    dr_kg_search_graph_builder,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
    dr_ws_graph_builder,
)
# from onyx.agents.agent_search.dr.conditional_edges import completeness_router
# from onyx.agents.agent_search.dr.conditional_edges import decision_router
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
# from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
# from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
# from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
# from onyx.agents.agent_search.dr.states import MainInput
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
#     dr_basic_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
#     dr_custom_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
#     dr_generic_internal_tool_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
#     dr_image_generation_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
#     dr_kg_search_graph_builder,
# )
# from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
#     dr_ws_graph_builder,
# )

# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
# # from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search


def dr_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the deep research agent.
    """
# def dr_graph_builder() -> StateGraph:
#     """
#     LangGraph graph builder for the deep research agent.
#     """

    graph = StateGraph(state_schema=MainState, input=MainInput)
#     graph = StateGraph(state_schema=MainState, input=MainInput)

    ### Add nodes ###
#     ### Add nodes ###

    graph.add_node(DRPath.CLARIFIER, clarifier)
#     graph.add_node(DRPath.CLARIFIER, clarifier)

    graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
#     graph.add_node(DRPath.ORCHESTRATOR, orchestrator)

    basic_search_graph = dr_basic_search_graph_builder().compile()
    graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
#     basic_search_graph = dr_basic_search_graph_builder().compile()
#     graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)

    kg_search_graph = dr_kg_search_graph_builder().compile()
    graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
#     kg_search_graph = dr_kg_search_graph_builder().compile()
#     graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)

    internet_search_graph = dr_ws_graph_builder().compile()
    graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
#     internet_search_graph = dr_ws_graph_builder().compile()
#     graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)

    image_generation_graph = dr_image_generation_graph_builder().compile()
    graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
#     image_generation_graph = dr_image_generation_graph_builder().compile()
#     graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)

    custom_tool_graph = dr_custom_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
#     custom_tool_graph = dr_custom_tool_graph_builder().compile()
#     graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)

    generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
    graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
#     generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
#     graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)

    graph.add_node(DRPath.CLOSER, closer)
    graph.add_node(DRPath.LOGGER, logging)
#     graph.add_node(DRPath.CLOSER, closer)
#     graph.add_node(DRPath.LOGGER, logging)

    ### Add edges ###
#     ### Add edges ###

    graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
#     graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)

    graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
#     graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)

    graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
#     graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)

    graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
    graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
#     graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)

    graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
    graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
#     graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
#     graph.add_edge(start_key=DRPath.LOGGER, end_key=END)

    return graph
#     return graph
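A hedged sketch of how a builder like this is typically consumed (the real call site is not part of this hunk, and the import path below is hypothetical):

# Hypothetical module path; the defining module's name is not shown in this diff.
from onyx.agents.agent_search.dr.dr_graph_builder import dr_graph_builder

compiled_graph = dr_graph_builder().compile()
# The input must satisfy MainInput, and the RunnableConfig is expected to carry
# the GraphConfig under config["metadata"]["config"], as the closer node below reads it.
# result = compiled_graph.invoke(input=main_input, config=runnable_config)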
@@ -1,131 +1,131 @@
from enum import Enum
# from enum import Enum

from pydantic import BaseModel
from pydantic import ConfigDict
# from pydantic import BaseModel

from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
    GeneratedImage,
)
from onyx.context.search.models import InferenceSection
from onyx.tools.tool import Tool
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
#     GeneratedImage,
# )
# from onyx.context.search.models import InferenceSection
# from onyx.tools.tool import Tool


class OrchestratorStep(BaseModel):
    tool: str
    questions: list[str]
# class OrchestratorStep(BaseModel):
#     tool: str
#     questions: list[str]


class OrchestratorDecisonsNoPlan(BaseModel):
    reasoning: str
    next_step: OrchestratorStep
# class OrchestratorDecisonsNoPlan(BaseModel):
#     reasoning: str
#     next_step: OrchestratorStep


class OrchestrationPlan(BaseModel):
    reasoning: str
    plan: str
# class OrchestrationPlan(BaseModel):
#     reasoning: str
#     plan: str


class ClarificationGenerationResponse(BaseModel):
    clarification_needed: bool
    clarification_question: str
# class ClarificationGenerationResponse(BaseModel):
#     clarification_needed: bool
#     clarification_question: str


class DecisionResponse(BaseModel):
    reasoning: str
    decision: str
# class DecisionResponse(BaseModel):
#     reasoning: str
#     decision: str


class QueryEvaluationResponse(BaseModel):
    reasoning: str
    query_permitted: bool
# class QueryEvaluationResponse(BaseModel):
#     reasoning: str
#     query_permitted: bool


class OrchestrationClarificationInfo(BaseModel):
    clarification_question: str
    clarification_response: str | None = None
# class OrchestrationClarificationInfo(BaseModel):
#     clarification_question: str
#     clarification_response: str | None = None


class WebSearchAnswer(BaseModel):
    urls_to_open_indices: list[int]
# class WebSearchAnswer(BaseModel):
#     urls_to_open_indices: list[int]


class SearchAnswer(BaseModel):
    reasoning: str
    answer: str
    claims: list[str] | None = None
# class SearchAnswer(BaseModel):
#     reasoning: str
#     answer: str
#     claims: list[str] | None = None


class TestInfoCompleteResponse(BaseModel):
    reasoning: str
    complete: bool
    gaps: list[str]
# class TestInfoCompleteResponse(BaseModel):
#     reasoning: str
#     complete: bool
#     gaps: list[str]


# TODO: revisit with custom tools implementation in v2
# each tool should be a class with the attributes below, plus the actual tool implementation
# this will also allow custom tools to have their own cost
class OrchestratorTool(BaseModel):
    tool_id: int
    name: str
    llm_path: str  # the path for the LLM to refer by
    path: DRPath  # the actual path in the graph
    description: str
    metadata: dict[str, str]
    cost: float
    tool_object: Tool | None = None  # None for CLOSER
# # TODO: revisit with custom tools implementation in v2
# # each tool should be a class with the attributes below, plus the actual tool implementation
# # this will also allow custom tools to have their own cost
# class OrchestratorTool(BaseModel):
#     tool_id: int
#     name: str
#     llm_path: str  # the path for the LLM to refer by
#     path: DRPath  # the actual path in the graph
#     description: str
#     metadata: dict[str, str]
#     cost: float
#     tool_object: Tool | None = None  # None for CLOSER

    model_config = ConfigDict(arbitrary_types_allowed=True)
#     class Config:
#         arbitrary_types_allowed = True


class IterationInstructions(BaseModel):
    iteration_nr: int
    plan: str | None
    reasoning: str
    purpose: str
# class IterationInstructions(BaseModel):
#     iteration_nr: int
#     plan: str | None
#     reasoning: str
#     purpose: str


class IterationAnswer(BaseModel):
    tool: str
    tool_id: int
    iteration_nr: int
    parallelization_nr: int
    question: str
    reasoning: str | None
    answer: str
    cited_documents: dict[int, InferenceSection]
    background_info: str | None = None
    claims: list[str] | None = None
    additional_data: dict[str, str] | None = None
    response_type: str | None = None
    data: dict | list | str | int | float | bool | None = None
    file_ids: list[str] | None = None
    # TODO: This is not ideal, but we'll can rework the schema
    # for deep research later
    is_web_fetch: bool = False
    # for image generation step-types
    generated_images: list[GeneratedImage] | None = None
    # for multi-query search tools (v2 web search and internal search)
    # TODO: Clean this up to be more flexible to tools
    queries: list[str] | None = None
# class IterationAnswer(BaseModel):
#     tool: str
#     tool_id: int
#     iteration_nr: int
#     parallelization_nr: int
#     question: str
#     reasoning: str | None
#     answer: str
#     cited_documents: dict[int, InferenceSection]
#     background_info: str | None = None
#     claims: list[str] | None = None
#     additional_data: dict[str, str] | None = None
#     response_type: str | None = None
#     data: dict | list | str | int | float | bool | None = None
#     file_ids: list[str] | None = None
#     # TODO: This is not ideal, but we'll can rework the schema
#     # for deep research later
#     is_web_fetch: bool = False
#     # for image generation step-types
#     generated_images: list[GeneratedImage] | None = None
#     # for multi-query search tools (v2 web search and internal search)
#     # TODO: Clean this up to be more flexible to tools
#     queries: list[str] | None = None


class AggregatedDRContext(BaseModel):
    context: str
    cited_documents: list[InferenceSection]
    is_internet_marker_dict: dict[str, bool]
    global_iteration_responses: list[IterationAnswer]
# class AggregatedDRContext(BaseModel):
#     context: str
#     cited_documents: list[InferenceSection]
#     is_internet_marker_dict: dict[str, bool]
#     global_iteration_responses: list[IterationAnswer]


class DRPromptPurpose(str, Enum):
    PLAN = "PLAN"
    NEXT_STEP = "NEXT_STEP"
    NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
    NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
    CLARIFICATION = "CLARIFICATION"
# class DRPromptPurpose(str, Enum):
#     PLAN = "PLAN"
#     NEXT_STEP = "NEXT_STEP"
#     NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
#     NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
#     CLARIFICATION = "CLARIFICATION"


class BaseSearchProcessingResponse(BaseModel):
    specified_source_types: list[str]
    rewritten_query: str
    time_filter: str
# class BaseSearchProcessingResponse(BaseModel):
#     specified_source_types: list[str]
#     rewritten_query: str
#     time_filter: str
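Aside on the OrchestratorTool config above: pydantic refuses field types it cannot validate (such as the repo's Tool) unless the model opts in, and model_config = ConfigDict(arbitrary_types_allowed=True) is the pydantic v2 spelling of the v1-style "class Config" shown in the commented-out version. A toy illustration (the class below is hypothetical):

from pydantic import BaseModel, ConfigDict


class NotAPydanticType:
    pass


class Wrapper(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)  # pydantic v2 idiom
    payload: NotAPydanticType | None = None  # accepted via a plain isinstance check


Wrapper(payload=NotAPydanticType())  # validates without error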
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,423 +1,418 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
from onyx.agents.agent_search.dr.enums import DRPath
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
from onyx.agents.agent_search.dr.states import FinalUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
    GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.dr.utils import get_prompt_question
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import StreamingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
# from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
# from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
# from onyx.agents.agent_search.dr.enums import DRPath
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
# from onyx.agents.agent_search.dr.states import FinalUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.states import OrchestrationUpdate
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
#     GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import get_chat_history_string
# from onyx.agents.agent_search.dr.utils import get_prompt_question
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.chat_utils import llm_doc_from_inference_section
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.llm.utils import check_number_of_tokens
# from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
# from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
# from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.server.query_and_chat.streaming_models import StreamingType
# from onyx.utils.logger import setup_logger
# from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
# logger = setup_logger()
def extract_citation_numbers(text: str) -> list[int]:
    """
    Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
    Returns a list of all unique citation numbers found.
    """
    # Pattern to match [[number]] or [[number1, number2, ...]]
    pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    matches = re.findall(pattern, text)
# def extract_citation_numbers(text: str) -> list[int]:
#     """
#     Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
#     Returns a list of all unique citation numbers found.
#     """
#     # Pattern to match [[number]] or [[number1, number2, ...]]
#     pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
#     matches = re.findall(pattern, text)

    cited_numbers = []
    for match in matches:
        # Split by comma and extract all numbers
        numbers = [int(num.strip()) for num in match.split(",")]
        cited_numbers.extend(numbers)
#     cited_numbers = []
#     for match in matches:
#         # Split by comma and extract all numbers
#         numbers = [int(num.strip()) for num in match.split(",")]
#         cited_numbers.extend(numbers)

    return list(set(cited_numbers))  # Return unique numbers
#     return list(set(cited_numbers))  # Return unique numbers
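A quick demonstration of the citation pattern above (illustrative):

import re

pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
text = "See [[3]] and the combined sources [[5, 7]]."
assert re.findall(pattern, text) == ["3", "5, 7"]
# extract_citation_numbers then splits each group on commas and de-duplicates,
# yielding the unique numbers 3, 5, and 7.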
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
#     citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
#     numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        if num - 1 < len(docs):  # Check bounds
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds
#     # For multiple citations like [[3, 5, 7]], create separate linked citations
#     linked_citations = []
#     for num in numbers:
#         if num - 1 < len(docs):  # Check bounds
#             link = docs[num - 1].link or ""
#             linked_citations.append(f"[[{num}]]({link})")
#         else:
#             linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
#     return "".join(linked_citations)
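A hedged sketch of how replace_citation_with_link is meant to be applied (assumption: the call site is not shown in this hunk; re.sub passes each match object to the callback, with the doc list bound via a lambda):

import re

linked_answer = re.sub(
    r"\[\[(\d+(?:,\s*\d+)*)\]\]",
    lambda match: replace_citation_with_link(match, docs),  # docs: list[DbSearchDoc]
    final_answer,
)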
def insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
# def insert_chat_message_search_doc_pair(
#     message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
#     """
#     Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.

    Args:
        message_id: The ID of the chat message
        search_doc_id: The ID of the search document
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
#     Args:
#         message_id: The ID of the chat message
#         search_doc_id: The ID of the search document
#         db_session: The database session
#     """
#     for search_doc_id in search_doc_ids:
#         chat_message_search_doc = ChatMessage__SearchDoc(
#             chat_message_id=message_id, search_doc_id=search_doc_id
#         )
#         db_session.add(chat_message_search_doc)
def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
) -> None:
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type
    db_session = graph_config.persistence.db_session
# def save_iteration(
#     state: MainState,
#     graph_config: GraphConfig,
#     aggregated_context: AggregatedDRContext,
#     final_answer: str,
#     all_cited_documents: list[InferenceSection],
#     is_internet_marker_dict: dict[str, bool],
# ) -> None:
#     db_session = graph_config.persistence.db_session
#     message_id = graph_config.persistence.message_id

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]
#     # first, insert the search_docs
#     search_docs = [
#         create_search_doc_from_inference_section(
#             inference_section=inference_section,
#             is_internet=is_internet_marker_dict.get(
#                 inference_section.center_chunk.document_id, False
#             ),  # TODO: revisit
#             db_session=db_session,
#             commit=False,
#         )
#         for inference_section in all_cited_documents
#     ]

    # then, map_search_docs to message
    insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )
#     # then, map_search_docs to message
#     insert_chat_message_search_doc_pair(
#         message_id, [search_doc.id for search_doc in search_docs], db_session
#     )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = extract_citation_numbers(final_answer)
    for cited_doc_nr in cited_doc_nrs:
        citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
#     # lastly, insert the citations
#     citation_dict: dict[int, int] = {}
#     cited_doc_nrs = extract_citation_numbers(final_answer)
#     for cited_doc_nr in cited_doc_nrs:
#         citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)
#     # TODO: generate plan as dict in the first place
#     plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
#     plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    )
#     # Update the chat message and its parent message in database
#     update_db_session_with_messages(
#         db_session=db_session,
#         chat_message_id=message_id,
#         chat_session_id=graph_config.persistence.chat_session_id,
#         is_agentic=graph_config.behavior.use_agentic_search,
#         message=final_answer,
#         citations=citation_dict,
#         research_type=None,  # research_type is deprecated
#         research_plan=plan_of_record_dict,
#         final_documents=search_docs,
#         update_parent_message=True,
#         research_answer_purpose=ResearchAnswerPurpose.ANSWER,
#     )

    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)
#     for iteration_preparation in state.iteration_instructions:
#         research_agent_iteration_step = ResearchAgentIteration(
#             primary_question_id=message_id,
#             reasoning=iteration_preparation.reasoning,
#             purpose=iteration_preparation.purpose,
#             iteration_nr=iteration_preparation.iteration_nr,
#         )
#         db_session.add(research_agent_iteration_step)

    for iteration_answer in aggregated_context.global_iteration_responses:
#     for iteration_answer in aggregated_context.global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
#         retrieved_search_docs = convert_inference_sections_to_search_docs(
#             list(iteration_answer.cited_documents.values())
#         )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
#         # Convert SavedSearchDoc objects to JSON-serializable format
#         serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)
#         research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
#             primary_question_id=message_id,
#             iteration_nr=iteration_answer.iteration_nr,
#             iteration_sub_step_nr=iteration_answer.parallelization_nr,
#             sub_step_instructions=iteration_answer.question,
#             sub_step_tool_id=iteration_answer.tool_id,
#             sub_answer=iteration_answer.answer,
#             reasoning=iteration_answer.reasoning,
#             claims=iteration_answer.claims,
#             cited_doc_results=serialized_search_docs,
#             generated_images=(
#                 GeneratedImageFullResult(images=iteration_answer.generated_images)
#                 if iteration_answer.generated_images
#                 else None
#             ),
#             additional_data=iteration_answer.additional_data,
#             queries=iteration_answer.queries,
#         )
#         db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
#     db_session.commit()
def closer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FinalUpdate | OrchestrationUpdate:
    """
    LangGraph node to close the DR process and finalize the answer.
    """
# def closer(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> FinalUpdate | OrchestrationUpdate:
#     """
#     LangGraph node to close the DR process and finalize the answer.
#     """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.
#     node_start_time = datetime.now()
#     # TODO: generate final answer using all the previous steps
#     # (right now, answers from each step are concatenated onto each other)
#     # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr
#     current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")
#     graph_config = cast(GraphConfig, config["metadata"]["config"])
#     base_question = state.original_question
#     if not base_question:
#         raise ValueError("Question is required for closer")

    research_type = graph_config.behavior.research_type
#     use_agentic_search = graph_config.behavior.use_agentic_search

    assistant_system_prompt: str = state.assistant_system_prompt or ""
    assistant_task_prompt = state.assistant_task_prompt
#     assistant_system_prompt: str = state.assistant_system_prompt or ""
#     assistant_task_prompt = state.assistant_task_prompt

    uploaded_context = state.uploaded_test_context or ""
#     uploaded_context = state.uploaded_test_context or ""

    clarification = state.clarification
    prompt_question = get_prompt_question(base_question, clarification)
#     clarification = state.clarification
#     prompt_question = get_prompt_question(base_question, clarification)

    chat_history_string = (
        get_chat_history_string(
            graph_config.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )
#     chat_history_string = (
#         get_chat_history_string(
#             graph_config.inputs.prompt_builder.message_history,
#             MAX_CHAT_HISTORY_MESSAGES,
#         )
#         or "(No chat history yet available)"
#     )

    aggregated_context_w_docs = aggregate_context(
        state.iteration_responses, include_documents=True
    )
#     aggregated_context_w_docs = aggregate_context(
#         state.iteration_responses, include_documents=True
#     )

    aggregated_context_wo_docs = aggregate_context(
        state.iteration_responses, include_documents=False
    )
#     aggregated_context_wo_docs = aggregate_context(
#         state.iteration_responses, include_documents=False
#     )

    iteration_responses_w_docs_string = aggregated_context_w_docs.context
    iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
    all_cited_documents = aggregated_context_w_docs.cited_documents
#     iteration_responses_w_docs_string = aggregated_context_w_docs.context
#     iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
#     all_cited_documents = aggregated_context_w_docs.cited_documents

    num_closer_suggestions = state.num_closer_suggestions
#     num_closer_suggestions = state.num_closer_suggestions

    if (
        num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
        and research_type == ResearchType.DEEP
    ):
        test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
            base_question=prompt_question,
            questions_answers_claims=iteration_responses_wo_docs_string,
            chat_history_string=chat_history_string,
            high_level_plan=(
                state.plan_of_record.plan
                if state.plan_of_record
                else "No plan available"
            ),
        )
#     if (
#         num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
#         and use_agentic_search
#     ):
#         test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
#             base_question=prompt_question,
#             questions_answers_claims=iteration_responses_wo_docs_string,
#             chat_history_string=chat_history_string,
#             high_level_plan=(
#                 state.plan_of_record.plan
#                 if state.plan_of_record
#                 else "No plan available"
#             ),
#         )

        test_info_complete_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt,
                test_info_complete_prompt + (assistant_task_prompt or ""),
            ),
            schema=TestInfoCompleteResponse,
            timeout_override=TF_DR_TIMEOUT_LONG,
            # max_tokens=1000,
        )
#         test_info_complete_json = invoke_llm_json(
#             llm=graph_config.tooling.primary_llm,
#             prompt=create_question_prompt(
#                 assistant_system_prompt,
#                 test_info_complete_prompt + (assistant_task_prompt or ""),
#             ),
#             schema=TestInfoCompleteResponse,
#             timeout_override=TF_DR_TIMEOUT_LONG,
#             # max_tokens=1000,
#         )

        if test_info_complete_json.complete:
            pass
#         if test_info_complete_json.complete:
#             pass

        else:
            return OrchestrationUpdate(
                tools_used=[DRPath.ORCHESTRATOR.value],
                query_list=[],
                log_messages=[
                    get_langgraph_node_log_string(
                        graph_component="main",
                        node_name="closer",
                        node_start_time=node_start_time,
                    )
                ],
                gaps=test_info_complete_json.gaps,
                num_closer_suggestions=num_closer_suggestions + 1,
|
||||
)
|
||||
# else:
|
||||
# return OrchestrationUpdate(
|
||||
# tools_used=[DRPath.ORCHESTRATOR.value],
|
||||
# query_list=[],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="closer",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# gaps=test_info_complete_json.gaps,
|
||||
# num_closer_suggestions=num_closer_suggestions + 1,
|
||||
# )
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
all_cited_documents
|
||||
)
|
||||
# retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
# all_cited_documents
|
||||
# )
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(
|
||||
content="",
|
||||
final_documents=retrieved_search_docs,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# MessageStart(
|
||||
# content="",
|
||||
# final_documents=retrieved_search_docs,
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
elif research_type == ResearchType.DEEP:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
else:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
# if not use_agentic_search:
|
||||
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
# else:
|
||||
# final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
|
||||
estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_w_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
)
|
||||
# estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
# final_answer_base_prompt.build(
|
||||
# base_question=prompt_question,
|
||||
# iteration_responses_string=iteration_responses_w_docs_string,
|
||||
# chat_history_string=chat_history_string,
|
||||
# uploaded_context=uploaded_context,
|
||||
# )
|
||||
# )
|
||||
|
||||
# for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# Should generally not be the case though.
|
||||
# # for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# # TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# # Should generally not be the case though.
|
||||
|
||||
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
# max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
|
||||
if (
|
||||
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
iteration_responses_string = iteration_responses_wo_docs_string
|
||||
else:
|
||||
iteration_responses_string = iteration_responses_w_docs_string
|
||||
# if (
|
||||
# estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
# and use_agentic_search
|
||||
# ):
|
||||
# iteration_responses_string = iteration_responses_wo_docs_string
|
||||
# else:
|
||||
# iteration_responses_string = iteration_responses_w_docs_string
|
||||
|
||||
final_answer_prompt = final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
# final_answer_prompt = final_answer_base_prompt.build(
|
||||
# base_question=prompt_question,
|
||||
# iteration_responses_string=iteration_responses_string,
|
||||
# chat_history_string=chat_history_string,
|
||||
# uploaded_context=uploaded_context,
|
||||
# )
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ (graph_config.inputs.project_instructions or "")
|
||||
)
|
||||
# if graph_config.inputs.project_instructions:
|
||||
# assistant_system_prompt = (
|
||||
# assistant_system_prompt
|
||||
# + PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
# + (graph_config.inputs.project_instructions or "")
|
||||
# )
|
||||
|
||||
all_context_llmdocs = [
|
||||
llm_doc_from_inference_section(inference_section)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
# all_context_llmdocs = [
|
||||
# llm_doc_from_inference_section(inference_section)
|
||||
# for inference_section in all_cited_documents
|
||||
# ]
|
||||
|
||||
try:
|
||||
streamed_output, _, citation_infos = run_with_timeout(
|
||||
int(3 * TF_DR_TIMEOUT_LONG),
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
final_answer_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
context_docs=all_context_llmdocs,
|
||||
replace_citations=True,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
# try:
|
||||
# streamed_output, _, citation_infos = run_with_timeout(
|
||||
# int(3 * TF_DR_TIMEOUT_LONG),
|
||||
# lambda: stream_llm_answer(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt,
|
||||
# final_answer_prompt + (assistant_task_prompt or ""),
|
||||
# ),
|
||||
# event_name="basic_response",
|
||||
# writer=writer,
|
||||
# agent_answer_level=0,
|
||||
# agent_answer_question_num=0,
|
||||
# agent_answer_type="agent_level_answer",
|
||||
# timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
# answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
# ind=current_step_nr,
|
||||
# context_docs=all_context_llmdocs,
|
||||
# replace_citations=True,
|
||||
# # max_tokens=None,
|
||||
# ),
|
||||
# )
|
||||
|
||||
final_answer = "".join(streamed_output)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in consolidate_research: {e}")
|
||||
# final_answer = "".join(streamed_output)
|
||||
# except Exception as e:
|
||||
# raise ValueError(f"Error in consolidate_research: {e}")
|
||||
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
# write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
write_custom_event(current_step_nr, CitationStart(), writer)
|
||||
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
# write_custom_event(current_step_nr, CitationStart(), writer)
|
||||
# write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
|
||||
# write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
# current_step_nr += 1
|
||||
|
||||
# Log the research agent steps
|
||||
# save_iteration(
|
||||
# state,
|
||||
# graph_config,
|
||||
# aggregated_context,
|
||||
# final_answer,
|
||||
# all_cited_documents,
|
||||
# is_internet_marker_dict,
|
||||
# )
|
||||
# # Log the research agent steps
|
||||
# # save_iteration(
|
||||
# # state,
|
||||
# # graph_config,
|
||||
# # aggregated_context,
|
||||
# # final_answer,
|
||||
# # all_cited_documents,
|
||||
# # is_internet_marker_dict,
|
||||
# # )
|
||||
|
||||
return FinalUpdate(
|
||||
final_answer=final_answer,
|
||||
all_cited_documents=all_cited_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="closer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return FinalUpdate(
|
||||
# final_answer=final_answer,
|
||||
# all_cited_documents=all_cited_documents,
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="main",
|
||||
# node_name="closer",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
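Worth noting in closer above: the final-answer prompt is built twice, once only to measure it. If the with-documents variant is estimated at more than 80% of the model's input window (and the run is DEEP research), the node falls back to the leaner sub-answers-only context. A standalone sketch of that guard, with the variable names assumed from the code above (illustrative, not part of the diff):

def pick_context(
    tokens_with_docs: int,
    max_input_tokens: int,
    is_deep: bool,
    context_with_docs: str,
    context_without_docs: str,
) -> str:
    # Keep roughly 20% headroom for instructions and generation.
    if is_deep and tokens_with_docs > 0.8 * max_input_tokens:
        return context_without_docs
    return context_with_docs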
@@ -1,248 +1,246 @@
import re
from datetime import datetime
from typing import cast
# import re
# from datetime import datetime
# from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter
# from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
    GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import aggregate_context
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
# from onyx.agents.agent_search.dr.models import AggregatedDRContext
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.states import MainState
# from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
#     GeneratedImageFullResult,
# )
# from onyx.agents.agent_search.dr.utils import aggregate_context
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import InferenceSection
# from onyx.db.chat import create_search_doc_from_inference_section
# from onyx.db.chat import update_db_session_with_messages
# from onyx.db.models import ChatMessage__SearchDoc
# from onyx.db.models import ResearchAgentIteration
# from onyx.db.models import ResearchAgentIterationSubStep
# from onyx.db.models import SearchDoc as DbSearchDoc
# from onyx.natural_language_processing.utils import get_tokenizer
# from onyx.server.query_and_chat.streaming_models import OverallStop
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def _extract_citation_numbers(text: str) -> list[int]:
    """
    Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
    Returns a list of all unique citation numbers found.
    """
    # Pattern to match [[number]] or [[number1, number2, ...]]
    pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
    matches = re.findall(pattern, text)
# def _extract_citation_numbers(text: str) -> list[int]:
#     """
#     Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
#     Returns a list of all unique citation numbers found.
#     """
#     # Pattern to match [[number]] or [[number1, number2, ...]]
#     pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
#     matches = re.findall(pattern, text)

    cited_numbers = []
    for match in matches:
        # Split by comma and extract all numbers
        numbers = [int(num.strip()) for num in match.split(",")]
        cited_numbers.extend(numbers)
    # cited_numbers = []
    # for match in matches:
    #     # Split by comma and extract all numbers
    #     numbers = [int(num.strip()) for num in match.split(",")]
    #     cited_numbers.extend(numbers)

    return list(set(cited_numbers))  # Return unique numbers
    # return list(set(cited_numbers))  # Return unique numbers
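The pattern r"\[\[(\d+(?:,\s*\d+)*)\]\]" matches both single and grouped citations, and the final set() deduplicates at the cost of ordering. Expected behavior, shown as an illustrative snippet (not part of the diff):

import re

pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
text = "See [[3]] and [[5, 7]], and [[3]] again."
matches = re.findall(pattern, text)  # ['3', '5, 7', '3']
numbers = {int(n.strip()) for m in matches for n in m.split(",")}
print(sorted(numbers))  # [3, 5, 7] -- duplicates collapsed, order not preserved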


def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
    citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
    numbers = [int(num.strip()) for num in citation_content.split(",")]
# def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
#     citation_content = match.group(1)  # e.g., "3" or "3, 5, 7"
#     numbers = [int(num.strip()) for num in citation_content.split(",")]

    # For multiple citations like [[3, 5, 7]], create separate linked citations
    linked_citations = []
    for num in numbers:
        if num - 1 < len(docs):  # Check bounds
            link = docs[num - 1].link or ""
            linked_citations.append(f"[[{num}]]({link})")
        else:
            linked_citations.append(f"[[{num}]]")  # No link if out of bounds
    # # For multiple citations like [[3, 5, 7]], create separate linked citations
    # linked_citations = []
    # for num in numbers:
    #     if num - 1 < len(docs):  # Check bounds
    #         link = docs[num - 1].link or ""
    #         linked_citations.append(f"[[{num}]]({link})")
    #     else:
    #         linked_citations.append(f"[[{num}]]")  # No link if out of bounds

    return "".join(linked_citations)
    # return "".join(linked_citations)


def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert pairs of message_id and each search_doc_id into the chat_message__search_doc table.
# def _insert_chat_message_search_doc_pair(
#     message_id: int, search_doc_ids: list[int], db_session: Session
# ) -> None:
#     """
#     Insert pairs of message_id and each search_doc_id into the chat_message__search_doc table.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
#     Args:
#         message_id: The ID of the chat message
#         search_doc_ids: The IDs of the search documents
#         db_session: The database session
#     """
#     for search_doc_id in search_doc_ids:
#         chat_message_search_doc = ChatMessage__SearchDoc(
#             chat_message_id=message_id, search_doc_id=search_doc_id
#         )
#         db_session.add(chat_message_search_doc)


def save_iteration(
    state: MainState,
    graph_config: GraphConfig,
    aggregated_context: AggregatedDRContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    is_internet_marker_dict: dict[str, bool],
    num_tokens: int,
) -> None:
    db_session = graph_config.persistence.db_session
    message_id = graph_config.persistence.message_id
    research_type = graph_config.behavior.research_type
# def save_iteration(
#     state: MainState,
#     graph_config: GraphConfig,
#     aggregated_context: AggregatedDRContext,
#     final_answer: str,
#     all_cited_documents: list[InferenceSection],
#     is_internet_marker_dict: dict[str, bool],
#     num_tokens: int,
# ) -> None:
#     db_session = graph_config.persistence.db_session
#     message_id = graph_config.persistence.message_id

    # first, insert the search_docs
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]
    # # first, insert the search_docs
    # search_docs = [
    #     create_search_doc_from_inference_section(
    #         inference_section=inference_section,
    #         is_internet=is_internet_marker_dict.get(
    #             inference_section.center_chunk.document_id, False
    #         ),  # TODO: revisit
    #         db_session=db_session,
    #         commit=False,
    #     )
    #     for inference_section in all_cited_documents
    # ]

    # then, map search_docs to the message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )
    # # then, map search_docs to the message
    # _insert_chat_message_search_doc_pair(
    #     message_id, [search_doc.id for search_doc in search_docs], db_session
    # )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
    # # lastly, insert the citations
    # citation_dict: dict[int, int] = {}
    # cited_doc_nrs = _extract_citation_numbers(final_answer)
    # if search_docs:
    #     for cited_doc_nr in cited_doc_nrs:
    #         citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    # TODO: generate plan as dict in the first place
    plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    plan_of_record_dict = parse_plan_to_dict(plan_of_record)
    # # TODO: generate plan as dict in the first place
    # plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
    # plan_of_record_dict = parse_plan_to_dict(plan_of_record)

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=graph_config.persistence.chat_session_id,
        is_agentic=graph_config.behavior.use_agentic_search,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan=plan_of_record_dict,
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )
    # # Update the chat message and its parent message in database
    # update_db_session_with_messages(
    #     db_session=db_session,
    #     chat_message_id=message_id,
    #     chat_session_id=graph_config.persistence.chat_session_id,
    #     is_agentic=graph_config.behavior.use_agentic_search,
    #     message=final_answer,
    #     citations=citation_dict,
    #     research_type=None,  # research_type is deprecated
    #     research_plan=plan_of_record_dict,
    #     final_documents=search_docs,
    #     update_parent_message=True,
    #     research_answer_purpose=ResearchAnswerPurpose.ANSWER,
    #     token_count=num_tokens,
    # )

    for iteration_preparation in state.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)
    # for iteration_preparation in state.iteration_instructions:
    #     research_agent_iteration_step = ResearchAgentIteration(
    #         primary_question_id=message_id,
    #         reasoning=iteration_preparation.reasoning,
    #         purpose=iteration_preparation.purpose,
    #         iteration_nr=iteration_preparation.iteration_nr,
    #     )
    #     db_session.add(research_agent_iteration_step)

    for iteration_answer in aggregated_context.global_iteration_responses:
    # for iteration_answer in aggregated_context.global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
        # retrieved_search_docs = convert_inference_sections_to_search_docs(
        #     list(iteration_answer.cited_documents.values())
        # )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
        # # Convert SavedSearchDoc objects to JSON-serializable format
        # serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)
        # research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
        #     primary_question_id=message_id,
        #     iteration_nr=iteration_answer.iteration_nr,
        #     iteration_sub_step_nr=iteration_answer.parallelization_nr,
        #     sub_step_instructions=iteration_answer.question,
        #     sub_step_tool_id=iteration_answer.tool_id,
        #     sub_answer=iteration_answer.answer,
        #     reasoning=iteration_answer.reasoning,
        #     claims=iteration_answer.claims,
        #     cited_doc_results=serialized_search_docs,
        #     generated_images=(
        #         GeneratedImageFullResult(images=iteration_answer.generated_images)
        #         if iteration_answer.generated_images
        #         else None
        #     ),
        #     additional_data=iteration_answer.additional_data,
        #     queries=iteration_answer.queries,
        # )
        # db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
    # db_session.commit()
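The citation step above leans on the convention that the LLM's [[n]] markers are one-based indices into search_docs, so citation_dict ends up mapping marker numbers to database row ids. Note that, unlike replace_citation_with_link, this loop does not bounds-check cited_doc_nr, so an out-of-range citation would raise IndexError. The mapping itself, with made-up ids (illustrative, not from the diff):

cited_doc_nrs = [1, 3]                # from _extract_citation_numbers
search_doc_ids = [41, 42, 43]         # ids assigned at insert time (made up)
citation_dict = {nr: search_doc_ids[nr - 1] for nr in cited_doc_nrs}
assert citation_dict == {1: 41, 3: 43}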


def logging(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to log the research agent steps and persist the final answer.
    """
# def logging(
#     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
#     """
#     LangGraph node to log the research agent steps and persist the final answer.
#     """

    node_start_time = datetime.now()
    # TODO: generate final answer using all the previous steps
    # (right now, answers from each step are concatenated onto each other)
    # Also, add missing fields once usage in UI is clear.
    # node_start_time = datetime.now()
    # # TODO: generate final answer using all the previous steps
    # # (right now, answers from each step are concatenated onto each other)
    # # Also, add missing fields once usage in UI is clear.

    current_step_nr = state.current_step_nr
    # current_step_nr = state.current_step_nr

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = state.original_question
    if not base_question:
        raise ValueError("Question is required for closer")
    # graph_config = cast(GraphConfig, config["metadata"]["config"])
    # base_question = state.original_question
    # if not base_question:
    #     raise ValueError("Question is required for closer")

    aggregated_context = aggregate_context(
        state.iteration_responses, include_documents=True
    )
    # aggregated_context = aggregate_context(
    #     state.iteration_responses, include_documents=True
    # )

    all_cited_documents = aggregated_context.cited_documents
    # all_cited_documents = aggregated_context.cited_documents

    is_internet_marker_dict = aggregated_context.is_internet_marker_dict
    # is_internet_marker_dict = aggregated_context.is_internet_marker_dict

    final_answer = state.final_answer or ""
    llm_provider = graph_config.tooling.primary_llm.config.model_provider
    llm_model_name = graph_config.tooling.primary_llm.config.model_name
    # final_answer = state.final_answer or ""
    # llm_provider = graph_config.tooling.primary_llm.config.model_provider
    # llm_model_name = graph_config.tooling.primary_llm.config.model_name

    llm_tokenizer = get_tokenizer(
        model_name=llm_model_name,
        provider_type=llm_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))
    # llm_tokenizer = get_tokenizer(
    #     model_name=llm_model_name,
    #     provider_type=llm_provider,
    # )
    # num_tokens = len(llm_tokenizer.encode(final_answer or ""))

    write_custom_event(current_step_nr, OverallStop(), writer)
    # write_custom_event(current_step_nr, OverallStop(), writer)

    # Log the research agent steps
    save_iteration(
        state,
        graph_config,
        aggregated_context,
        final_answer,
        all_cited_documents,
        is_internet_marker_dict,
        num_tokens,
    )
    # # Log the research agent steps
    # save_iteration(
    #     state,
    #     graph_config,
    #     aggregated_context,
    #     final_answer,
    #     all_cited_documents,
    #     is_internet_marker_dict,
    #     num_tokens,
    # )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="logger",
                node_start_time=node_start_time,
            )
        ],
    )
    # return LoggerUpdate(
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="main",
    #             node_name="logger",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

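The token count persisted above is computed by round-tripping the final answer through the provider-matched tokenizer. A minimal sketch of the same call chain, assuming the get_tokenizer helper imported at the top of this file; the model and provider values are illustrative:

llm_tokenizer = get_tokenizer(
    model_name="gpt-4o",       # illustrative
    provider_type="openai",    # illustrative
)
num_tokens = len(llm_tokenizer.encode("final answer text"))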
@@ -1,132 +1,131 @@
from collections.abc import Iterator
from typing import cast
# from collections.abc import Iterator
# from typing import cast

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from pydantic import BaseModel
# from langchain_core.messages import AIMessageChunk
# from langchain_core.messages import BaseMessage
# from langgraph.types import StreamWriter
# from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
    PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
# from onyx.chat.models import AgentAnswerPiece
# from onyx.chat.models import CitationInfo
# from onyx.chat.models import LlmDoc
# from onyx.chat.models import OnyxAnswerPiece
# from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
# from onyx.chat.stream_processing.answer_response_handler import (
#     PassThroughAnswerResponseHandler,
# )
# from onyx.chat.stream_processing.utils import map_document_id_order
# from onyx.context.search.models import InferenceSection
# from onyx.server.query_and_chat.streaming_models import CitationDelta
# from onyx.server.query_and_chat.streaming_models import CitationStart
# from onyx.server.query_and_chat.streaming_models import MessageDelta
# from onyx.server.query_and_chat.streaming_models import MessageStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


class BasicSearchProcessedStreamResults(BaseModel):
    ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
    full_answer: str | None = None
    cited_references: list[InferenceSection] = []
    retrieved_documents: list[LlmDoc] = []
# class BasicSearchProcessedStreamResults(BaseModel):
#     ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
#     full_answer: str | None = None
#     cited_references: list[InferenceSection] = []
#     retrieved_documents: list[LlmDoc] = []
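BasicSearchProcessedStreamResults is just an aggregate return type: the accumulated tool-call chunk, the concatenated answer text, and any reference documents. Every field has a default, so partial construction is fine (illustrative, not part of the diff):

result = BasicSearchProcessedStreamResults(full_answer="Hello world")
print(result.ai_message_chunk.content)  # "" -- the default empty chunk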


def process_llm_stream(
    messages: Iterator[BaseMessage],
    should_stream_answer: bool,
    writer: StreamWriter,
    ind: int,
    search_results: list[LlmDoc] | None = None,
    generate_final_answer: bool = False,
    chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
    tool_call_chunk = AIMessageChunk(content="")
# def process_llm_stream(
#     messages: Iterator[BaseMessage],
#     should_stream_answer: bool,
#     writer: StreamWriter,
#     ind: int,
#     search_results: list[LlmDoc] | None = None,
#     generate_final_answer: bool = False,
#     chat_message_id: str | None = None,
# ) -> BasicSearchProcessedStreamResults:
#     tool_call_chunk = AIMessageChunk(content="")

    if search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=search_results,
            doc_id_to_rank_map=map_document_id_order(search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()
    # if search_results:
    #     answer_handler: AnswerResponseHandler = CitationResponseHandler(
    #         context_docs=search_results,
    #         doc_id_to_rank_map=map_document_id_order(search_results),
    #     )
    # else:
    #     answer_handler = PassThroughAnswerResponseHandler()

    full_answer = ""
    start_final_answer_streaming_set = False
    # Accumulate citation infos if handler emits them
    collected_citation_infos: list[CitationInfo] = []
    # full_answer = ""
    # start_final_answer_streaming_set = False
    # # Accumulate citation infos if handler emits them
    # collected_citation_infos: list[CitationInfo] = []

    # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # the stream will contain AIMessageChunks with tool call information.
    for message in messages:
    # # This stream will be the llm answer if no tool is chosen. When a tool is chosen,
    # # the stream will contain AIMessageChunks with tool call information.
    # for message in messages:

        answer_piece = message.content
        if not isinstance(answer_piece, str):
            # this is only used for logging, so fine to
            # just add the string representation
            answer_piece = str(answer_piece)
        full_answer += answer_piece
        # answer_piece = message.content
        # if not isinstance(answer_piece, str):
        #     # this is only used for logging, so fine to
        #     # just add the string representation
        #     answer_piece = str(answer_piece)
        # full_answer += answer_piece

        if isinstance(message, AIMessageChunk) and (
            message.tool_call_chunks or message.tool_calls
        ):
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message):
        # if isinstance(message, AIMessageChunk) and (
        #     message.tool_call_chunks or message.tool_calls
        # ):
        #     tool_call_chunk += message  # type: ignore
        # elif should_stream_answer:
        #     for response_part in answer_handler.handle_response_part(message):

                # only stream out answer parts
                if (
                    isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
                    and generate_final_answer
                    and response_part.answer_piece
                ):
                    if chat_message_id is None:
                        raise ValueError(
                            "chat_message_id is required when generating final answer"
                        )
                # # only stream out answer parts
                # if (
                #     isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
                #     and generate_final_answer
                #     and response_part.answer_piece
                # ):
                #     if chat_message_id is None:
                #         raise ValueError(
                #             "chat_message_id is required when generating final answer"
                #         )

                    if not start_final_answer_streaming_set:
                        # Convert LlmDocs to SavedSearchDocs
                        saved_search_docs = saved_search_docs_from_llm_docs(
                            search_results
                        )
                        write_custom_event(
                            ind,
                            MessageStart(content="", final_documents=saved_search_docs),
                            writer,
                        )
                        start_final_answer_streaming_set = True
                    # if not start_final_answer_streaming_set:
                    #     # Convert LlmDocs to SavedSearchDocs
                    #     saved_search_docs = saved_search_docs_from_llm_docs(
                    #         search_results
                    #     )
                    #     write_custom_event(
                    #         ind,
                    #         MessageStart(content="", final_documents=saved_search_docs),
                    #         writer,
                    #     )
                    #     start_final_answer_streaming_set = True

                    write_custom_event(
                        ind,
                        MessageDelta(content=response_part.answer_piece),
                        writer,
                    )
                # collect citation info objects
                elif isinstance(response_part, CitationInfo):
                    collected_citation_infos.append(response_part)
                    # write_custom_event(
                    #     ind,
                    #     MessageDelta(content=response_part.answer_piece),
                    #     writer,
                    # )
                # # collect citation info objects
                # elif isinstance(response_part, CitationInfo):
                #     collected_citation_infos.append(response_part)

    if generate_final_answer and start_final_answer_streaming_set:
        # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
        write_custom_event(
            ind,
            SectionEnd(),
            writer,
        )
    # if generate_final_answer and start_final_answer_streaming_set:
    #     # start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
    #     write_custom_event(
    #         ind,
    #         SectionEnd(),
    #         writer,
    #     )

        # Emit citations section if any were collected
        if collected_citation_infos:
            write_custom_event(ind, CitationStart(), writer)
            write_custom_event(
                ind, CitationDelta(citations=collected_citation_infos), writer
            )
            write_custom_event(ind, SectionEnd(), writer)
        # # Emit citations section if any were collected
        # if collected_citation_infos:
        #     write_custom_event(ind, CitationStart(), writer)
        #     write_custom_event(
        #         ind, CitationDelta(citations=collected_citation_infos), writer
        #     )
        #     write_custom_event(ind, SectionEnd(), writer)

    logger.debug(f"Full answer: {full_answer}")
    return BasicSearchProcessedStreamResults(
        ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
    )
    # logger.debug(f"Full answer: {full_answer}")
    # return BasicSearchProcessedStreamResults(
    #     ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
    # )

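In the no-tool path, process_llm_stream simply concatenates chunk contents; streaming events are only emitted when generate_final_answer is set and the stream is verbal. A minimal sketch, assuming the langchain_core AIMessageChunk import above (the chunk values are made up):

chunks = iter(
    [AIMessageChunk(content="Hello "), AIMessageChunk(content="world")]
)
result = process_llm_stream(
    messages=chunks,
    should_stream_answer=True,
    writer=lambda _: None,  # no-op writer, mirroring the node defaults
    ind=0,
)
assert result.full_answer == "Hello world"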
@@ -1,82 +1,82 @@
from operator import add
from typing import Annotated
from typing import Any
from typing import TypedDict
# from operator import add
# from typing import Annotated
# from typing import Any
# from typing import TypedDict

from pydantic import BaseModel
# from pydantic import BaseModel

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestrationPlan
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
# from onyx.agents.agent_search.core_state import CoreState
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import IterationInstructions
# from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
# from onyx.agents.agent_search.dr.models import OrchestrationPlan
# from onyx.agents.agent_search.dr.models import OrchestratorTool
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource

### States ###
# ### States ###


class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []
# class LoggerUpdate(BaseModel):
#     log_messages: Annotated[list[str], add] = []


class OrchestrationUpdate(LoggerUpdate):
    tools_used: Annotated[list[str], add] = []
    query_list: list[str] = []
    iteration_nr: int = 0
    current_step_nr: int = 1
    plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
    remaining_time_budget: float = 2.0  # set by default to about 2 searches
    num_closer_suggestions: int = 0  # how many times the closer was suggested
    gaps: list[str] = (
        []
    )  # gaps that may be identified by the closer before being able to answer the question.
    iteration_instructions: Annotated[list[IterationInstructions], add] = []
# class OrchestrationUpdate(LoggerUpdate):
#     tools_used: Annotated[list[str], add] = []
#     query_list: list[str] = []
#     iteration_nr: int = 0
#     current_step_nr: int = 1
#     plan_of_record: OrchestrationPlan | None = None  # None for Thoughtful
#     remaining_time_budget: float = 2.0  # set by default to about 2 searches
#     num_closer_suggestions: int = 0  # how many times the closer was suggested
#     gaps: list[str] = (
#         []
#     )  # gaps that may be identified by the closer before being able to answer the question.
#     iteration_instructions: Annotated[list[IterationInstructions], add] = []


class OrchestrationSetup(OrchestrationUpdate):
    original_question: str | None = None
    chat_history_string: str | None = None
    clarification: OrchestrationClarificationInfo | None = None
    available_tools: dict[str, OrchestratorTool] | None = None
    num_closer_suggestions: int = 0  # how many times the closer was suggested
# class OrchestrationSetup(OrchestrationUpdate):
#     original_question: str | None = None
#     chat_history_string: str | None = None
#     clarification: OrchestrationClarificationInfo | None = None
#     available_tools: dict[str, OrchestratorTool] | None = None
#     num_closer_suggestions: int = 0  # how many times the closer was suggested

    active_source_types: list[DocumentSource] | None = None
    active_source_types_descriptions: str | None = None
    assistant_system_prompt: str | None = None
    assistant_task_prompt: str | None = None
    uploaded_test_context: str | None = None
    uploaded_image_context: list[dict[str, Any]] | None = None
    # active_source_types: list[DocumentSource] | None = None
    # active_source_types_descriptions: str | None = None
    # assistant_system_prompt: str | None = None
    # assistant_task_prompt: str | None = None
    # uploaded_test_context: str | None = None
    # uploaded_image_context: list[dict[str, Any]] | None = None


class AnswerUpdate(LoggerUpdate):
    iteration_responses: Annotated[list[IterationAnswer], add] = []
# class AnswerUpdate(LoggerUpdate):
#     iteration_responses: Annotated[list[IterationAnswer], add] = []


class FinalUpdate(LoggerUpdate):
    final_answer: str | None = None
    all_cited_documents: list[InferenceSection] = []
# class FinalUpdate(LoggerUpdate):
#     final_answer: str | None = None
#     all_cited_documents: list[InferenceSection] = []


## Graph Input State
class MainInput(CoreState):
    pass
# ## Graph Input State
# class MainInput(CoreState):
#     pass


## Graph State
class MainState(
    # This includes the core state
    MainInput,
    OrchestrationSetup,
    AnswerUpdate,
    FinalUpdate,
):
    pass
# ## Graph State
# class MainState(
#     # This includes the core state
#     MainInput,
#     OrchestrationSetup,
#     AnswerUpdate,
#     FinalUpdate,
# ):
#     pass


## Graph Output State
class MainOutput(TypedDict):
    log_messages: list[str]
    final_answer: str | None
    all_cited_documents: list[InferenceSection]
# ## Graph Output State
# class MainOutput(TypedDict):
#     log_messages: list[str]
#     final_answer: str | None
#     all_cited_documents: list[InferenceSection]

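The Annotated[list[...], add] fields in these state classes are LangGraph reducer annotations: when several nodes return updates for the same key, the values are merged with operator.add (list concatenation) rather than overwritten, which is what lets parallel sub-agents each append to iteration_responses safely. The merge itself is plain Python (illustrative values):

from operator import add

existing = ["closer: 1.2s"]
update = ["logger: 0.1s"]
merged = add(existing, update)  # equivalent to existing + update
assert merged == ["closer: 1.2s", "logger: 0.1s"]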
@@ -1,47 +1,47 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import SearchToolStart
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def basic_search_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """
# def basic_search_branch(
#     state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
# ) -> LoggerUpdate:
#     """
#     LangGraph node to perform a standard search as part of the DR process.
#     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    current_step_nr = state.current_step_nr
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr
    # current_step_nr = state.current_step_nr

    logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
    # logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")

    write_custom_event(
        current_step_nr,
        SearchToolStart(
            is_internet_search=False,
        ),
        writer,
    )
    # write_custom_event(
    #     current_step_nr,
    #     SearchToolStart(
    #         is_internet_search=False,
    #     ),
    #     writer,
    # )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
    # return LoggerUpdate(
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="basic_search",
    #             node_name="branching",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

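The writer: StreamWriter = lambda _: None default appears on every node in this diff: it makes streaming optional, so a node invoked outside a streaming graph run simply discards its events. Roughly (illustrative sketch, not from the diff):

def emit(event: object, writer=lambda _: None) -> None:
    # With the default no-op writer this is safe to call anywhere;
    # inside a streaming run, LangGraph injects a real StreamWriter.
    writer(event)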
@@ -1,286 +1,261 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
# import re
# from datetime import datetime
# from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import SearchAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.agents.agent_search.dr.utils import extract_document_citations
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.models import LlmDoc
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.context.search.models import InferenceSection
from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
# from onyx.agents.agent_search.dr.models import IterationAnswer
# from onyx.agents.agent_search.dr.models import SearchAnswer
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
# from onyx.agents.agent_search.dr.utils import extract_document_citations
# from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.agents.agent_search.utils import create_question_prompt
# from onyx.chat.models import LlmDoc
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.context.search.models import InferenceSection
# from onyx.db.connector import DocumentSource
# from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
# from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
# from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
# from onyx.server.query_and_chat.streaming_models import SearchToolDelta
# from onyx.tools.models import SearchToolOverrideKwargs
# from onyx.tools.tool_implementations.search.search_tool import SearchTool
# from onyx.tools.tool_implementations.search_like_tool_utils import (
#     SEARCH_INFERENCE_SECTIONS_ID,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def basic_search(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """
# def basic_search(
#     state: BranchInput,
#     config: RunnableConfig,
#     writer: StreamWriter = lambda _: None,
# ) -> BranchUpdate:
#     """
#     LangGraph node to perform a standard search as part of the DR process.
#     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    current_step_nr = state.current_step_nr
    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr
    # parallelization_nr = state.parallelization_nr
    # current_step_nr = state.current_step_nr
    # assistant_system_prompt = state.assistant_system_prompt
    # assistant_task_prompt = state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    # branch_query = state.branch_question
    # if not branch_query:
    #     raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type
    # graph_config = cast(GraphConfig, config["metadata"]["config"])
    # base_question = graph_config.inputs.prompt_builder.raw_user_query
    # use_agentic_search = graph_config.behavior.use_agentic_search

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # if not state.available_tools:
    #     raise ValueError("available_tools is not set")

    elif len(state.tools_used) == 0:
        raise ValueError("tools_used is empty")
    # elif len(state.tools_used) == 0:
    #     raise ValueError("tools_used is empty")

    search_tool_info = state.available_tools[state.tools_used[-1]]
    search_tool = cast(SearchTool, search_tool_info.tool_object)
    force_use_tool = graph_config.tooling.force_use_tool
    # search_tool_info = state.available_tools[state.tools_used[-1]]
    # search_tool = cast(SearchTool, search_tool_info.tool_object)
    # graph_config.tooling.force_use_tool

    # sanity check
    if search_tool != graph_config.tooling.search_tool:
        raise ValueError("search_tool does not match the configured search tool")
    # # sanity check
    # if search_tool != graph_config.tooling.search_tool:
    #     raise ValueError("search_tool does not match the configured search tool")

    # rewrite query and identify source types
    active_source_types_str = ", ".join(
        [source.value for source in state.active_source_types or []]
    )
    # # rewrite query and identify source types
    # active_source_types_str = ", ".join(
    #     [source.value for source in state.active_source_types or []]
    # )

    base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
        active_source_types_str=active_source_types_str,
        branch_query=branch_query,
        current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
    )
    # base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
    #     active_source_types_str=active_source_types_str,
    #     branch_query=branch_query,
    #     current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
    # )

    try:
        search_processing = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, base_search_processing_prompt
            ),
            schema=BaseSearchProcessingResponse,
            timeout_override=TF_DR_TIMEOUT_SHORT,
            # max_tokens=100,
        )
    except Exception as e:
        logger.error(f"Could not process query: {e}")
        raise e
    # try:
    #     search_processing = invoke_llm_json(
    #         llm=graph_config.tooling.primary_llm,
    #         prompt=create_question_prompt(
    #             assistant_system_prompt, base_search_processing_prompt
|
||||
# ),
|
||||
# schema=BaseSearchProcessingResponse,
|
||||
# timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# # max_tokens=100,
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error(f"Could not process query: {e}")
|
||||
# raise e
|
||||
|
||||
rewritten_query = search_processing.rewritten_query
|
||||
# rewritten_query = search_processing.rewritten_query
|
||||
|
||||
# give back the query so we can render it in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[rewritten_query],
|
||||
documents=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# # give back the query so we can render it in the UI
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[rewritten_query],
|
||||
# documents=[],
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
implied_start_date = search_processing.time_filter
|
||||
# implied_start_date = search_processing.time_filter
|
||||
|
||||
# Validate time_filter format if it exists
|
||||
implied_time_filter = None
|
||||
if implied_start_date:
|
||||
# # Validate time_filter format if it exists
|
||||
# implied_time_filter = None
|
||||
# if implied_start_date:
|
||||
|
||||
# Check if time_filter is in YYYY-MM-DD format
|
||||
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
|
||||
if re.match(date_pattern, implied_start_date):
|
||||
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
|
||||
# # Check if time_filter is in YYYY-MM-DD format
|
||||
# date_pattern = r"^\d{4}-\d{2}-\d{2}$"
|
||||
# if re.match(date_pattern, implied_start_date):
|
||||
# implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
|
||||
|
||||
specified_source_types: list[DocumentSource] | None = (
|
||||
strings_to_document_sources(search_processing.specified_source_types)
|
||||
if search_processing.specified_source_types
|
||||
else None
|
||||
)
|
||||
# specified_source_types: list[DocumentSource] | None = (
|
||||
# strings_to_document_sources(search_processing.specified_source_types)
|
||||
# if search_processing.specified_source_types
|
||||
# else None
|
||||
# )
|
||||
|
||||
if specified_source_types is not None and len(specified_source_types) == 0:
|
||||
specified_source_types = None
|
||||
# if specified_source_types is not None and len(specified_source_types) == 0:
|
||||
# specified_source_types = None
|
||||
|
||||
logger.debug(
|
||||
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
# retrieved_docs: list[InferenceSection] = []
|
||||
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
if force_use_tool.override_kwargs and isinstance(
|
||||
force_use_tool.override_kwargs, SearchToolOverrideKwargs
|
||||
):
|
||||
override_kwargs = force_use_tool.override_kwargs
|
||||
user_file_ids = override_kwargs.user_file_ids
|
||||
project_id = override_kwargs.project_id
|
||||
# for tool_response in search_tool.run(
|
||||
# query=rewritten_query,
|
||||
# document_sources=specified_source_types,
|
||||
# time_filter=implied_time_filter,
|
||||
# override_kwargs=SearchToolOverrideKwargs(original_query=rewritten_query),
|
||||
# ):
|
||||
# # get retrieved docs to send to the rest of the graph
|
||||
# if tool_response.id == SEARCH_INFERENCE_SECTIONS_ID:
|
||||
# retrieved_docs = cast(list[InferenceSection], tool_response.response)
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_with_current_tenant() as search_db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=rewritten_query,
|
||||
document_sources=specified_source_types,
|
||||
time_filter=implied_time_filter,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=search_db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
original_query=rewritten_query,
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
# break
|
||||
|
||||
break
|
||||
# # render the retrieved docs in the UI
|
||||
# write_custom_event(
|
||||
# current_step_nr,
|
||||
# SearchToolDelta(
|
||||
# queries=[],
|
||||
# documents=convert_inference_sections_to_search_docs(
|
||||
# retrieved_docs, is_internet=False
|
||||
# ),
|
||||
# ),
|
||||
# writer,
|
||||
# )
|
||||
|
||||
# render the retrieved docs in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[],
|
||||
documents=convert_inference_sections_to_search_docs(
|
||||
retrieved_docs, is_internet=False
|
||||
),
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# document_texts_list = []
|
||||
|
||||
document_texts_list = []
|
||||
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
|
||||
# if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
# raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
# chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
# document_texts_list.append(chunk_text)
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
# document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
# logger.debug(
|
||||
# f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
logger.debug(
|
||||
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# # Built prompt
|
||||
|
||||
# Built prompt
|
||||
# if use_agentic_search:
|
||||
# search_prompt = INTERNAL_SEARCH_PROMPTS[use_agentic_search].build(
|
||||
# search_query=branch_query,
|
||||
# base_question=base_question,
|
||||
# document_text=document_texts,
|
||||
# )
|
||||
|
||||
if research_type == ResearchType.DEEP:
|
||||
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
|
||||
search_query=branch_query,
|
||||
base_question=base_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
# # Run LLM
|
||||
|
||||
# Run LLM
|
||||
# # search_answer_json = None
|
||||
# search_answer_json = invoke_llm_json(
|
||||
# llm=graph_config.tooling.primary_llm,
|
||||
# prompt=create_question_prompt(
|
||||
# assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
# ),
|
||||
# schema=SearchAnswer,
|
||||
# timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# # max_tokens=1500,
|
||||
# )
|
||||
|
||||
# search_answer_json = None
|
||||
search_answer_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
),
|
||||
schema=SearchAnswer,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
# logger.debug(
|
||||
# f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
# )
|
||||
|
||||
logger.debug(
|
||||
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
# # get cited documents
|
||||
# answer_string = search_answer_json.answer
|
||||
# claims = search_answer_json.claims or []
|
||||
# reasoning = search_answer_json.reasoning
|
||||
# # answer_string = ""
|
||||
# # claims = []
|
||||
|
||||
# get cited documents
|
||||
answer_string = search_answer_json.answer
|
||||
claims = search_answer_json.claims or []
|
||||
reasoning = search_answer_json.reasoning
|
||||
# answer_string = ""
|
||||
# claims = []
|
||||
# (
|
||||
# citation_numbers,
|
||||
# answer_string,
|
||||
# claims,
|
||||
# ) = extract_document_citations(answer_string, claims)
|
||||
|
||||
(
|
||||
citation_numbers,
|
||||
answer_string,
|
||||
claims,
|
||||
) = extract_document_citations(answer_string, claims)
|
||||
# if citation_numbers and (
|
||||
# (max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
# ):
|
||||
# raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
|
||||
if citation_numbers and (
|
||||
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
):
|
||||
raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
# cited_documents = {
|
||||
# citation_number: retrieved_docs[citation_number - 1]
|
||||
# for citation_number in citation_numbers
|
||||
# }
|
||||
|
||||
cited_documents = {
|
||||
citation_number: retrieved_docs[citation_number - 1]
|
||||
for citation_number in citation_numbers
|
||||
}
|
||||
# else:
|
||||
# answer_string = ""
|
||||
# claims = []
|
||||
# cited_documents = {
|
||||
# doc_num + 1: retrieved_doc
|
||||
# for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
|
||||
# }
|
||||
# reasoning = ""
|
||||
|
||||
else:
|
||||
answer_string = ""
|
||||
claims = []
|
||||
cited_documents = {
|
||||
doc_num + 1: retrieved_doc
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
|
||||
}
|
||||
reasoning = ""
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=search_tool_info.llm_path,
|
||||
tool_id=search_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=claims,
|
||||
cited_documents=cited_documents,
|
||||
reasoning=reasoning,
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="searching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return BranchUpdate(
|
||||
# branch_iteration_responses=[
|
||||
# IterationAnswer(
|
||||
# tool=search_tool_info.llm_path,
|
||||
# tool_id=search_tool_info.tool_id,
|
||||
# iteration_nr=iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# question=branch_query,
|
||||
# answer=answer_string,
|
||||
# claims=claims,
|
||||
# cited_documents=cited_documents,
|
||||
# reasoning=reasoning,
|
||||
# additional_data=None,
|
||||
# )
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="basic_search",
|
||||
# node_name="searching",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
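The time-filter handling in basic_search above accepts only a strict YYYY-MM-DD string from the LLM before converting it with strptime. A minimal standalone sketch of that validation pattern, with illustrative (non-Onyx) names:

import re
from datetime import datetime

DATE_PATTERN = r"^\d{4}-\d{2}-\d{2}$"  # strict YYYY-MM-DD, as in basic_search


def parse_time_filter(implied_start_date: str | None) -> datetime | None:
    """Return a datetime only when the LLM emitted a well-formed date string."""
    if implied_start_date and re.match(DATE_PATTERN, implied_start_date):
        return datetime.strptime(implied_start_date, "%Y-%m-%d")
    return None


assert parse_time_filter("2024-01-31") == datetime(2024, 1, 31)
assert parse_time_filter("last week") is None
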
@@ -1,77 +1,77 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.context.search.models import SavedSearchDoc
# from onyx.context.search.models import SearchDoc
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger


logger = setup_logger()
# logger = setup_logger()


def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a standard search as part of the DR process.
    """
    # def is_reducer(
    #     state: SubAgentMainState,
    #     config: RunnableConfig,
    #     writer: StreamWriter = lambda _: None,
    # ) -> SubAgentUpdate:
    #     """
    #     LangGraph node to perform a standard search as part of the DR process.
    #     """

    node_start_time = datetime.now()
    # node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr
    # branch_updates = state.branch_iteration_responses
    # current_iteration = state.iteration_nr
    # current_step_nr = state.current_step_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    # new_updates = [
    #     update for update in branch_updates if update.iteration_nr == current_iteration
    # ]

    [update.question for update in new_updates]
    doc_lists = [list(update.cited_documents.values()) for update in new_updates]
    # [update.question for update in new_updates]
    # doc_lists = [list(update.cited_documents.values()) for update in new_updates]

    doc_list = []
    # doc_list = []

    for xs in doc_lists:
        for x in xs:
            doc_list.append(x)
    # for xs in doc_lists:
    #     for x in xs:
    #         doc_list.append(x)

    # Convert InferenceSections to SavedSearchDocs
    search_docs = SearchDoc.from_chunks_or_sections(doc_list)
    retrieved_saved_search_docs = [
        SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
        for search_doc in search_docs
    ]
    # # Convert InferenceSections to SavedSearchDocs
    # search_docs = SearchDoc.from_chunks_or_sections(doc_list)
    # retrieved_saved_search_docs = [
    #     SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
    #     for search_doc in search_docs
    # ]

    for retrieved_saved_search_doc in retrieved_saved_search_docs:
        retrieved_saved_search_doc.is_internet = False
    # for retrieved_saved_search_doc in retrieved_saved_search_docs:
    #     retrieved_saved_search_doc.is_internet = False

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )
    # write_custom_event(
    #     current_step_nr,
    #     SectionEnd(),
    #     writer,
    # )

    current_step_nr += 1
    # current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="basic_search",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
    # return SubAgentUpdate(
    #     iteration_responses=new_updates,
    #     current_step_nr=current_step_nr,
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="basic_search",
    #             node_name="consolidation",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
    basic_search_branch,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
    basic_search,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
    is_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
#     basic_search_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
#     basic_search,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
#     is_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
#     branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger


logger = setup_logger()
# logger = setup_logger()


def dr_basic_search_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Web Search Sub-Agent
    """
    # def dr_basic_search_graph_builder() -> StateGraph:
    #     """
    #     LangGraph graph builder for Web Search Sub-Agent
    #     """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    # graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    # ### Add nodes ###

    graph.add_node("branch", basic_search_branch)
    # graph.add_node("branch", basic_search_branch)

    graph.add_node("act", basic_search)
    # graph.add_node("act", basic_search)

    graph.add_node("reducer", is_reducer)
    # graph.add_node("reducer", is_reducer)

    ### Add edges ###
    # ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")
    # graph.add_edge(start_key=START, end_key="branch")

    graph.add_conditional_edges("branch", branching_router)
    # graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")
    # graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)
    # graph.add_edge(start_key="reducer", end_key=END)

    return graph
    # return graph

@@ -1,30 +1,30 @@
from collections.abc import Hashable
# from collections.abc import Hashable

from langgraph.types import Send
# from langgraph.types import Send

from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                current_step_nr=state.current_step_nr,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:MAX_DR_PARALLEL_SEARCH]
        )
    ]
    # def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    #     return [
    #         Send(
    #             "act",
    #             BranchInput(
    #                 iteration_nr=state.iteration_nr,
    #                 parallelization_nr=parallelization_nr,
    #                 branch_question=query,
    #                 current_step_nr=state.current_step_nr,
    #                 context="",
    #                 active_source_types=state.active_source_types,
    #                 tools_used=state.tools_used,
    #                 available_tools=state.available_tools,
    #                 assistant_system_prompt=state.assistant_system_prompt,
    #                 assistant_task_prompt=state.assistant_task_prompt,
    #             ),
    #         )
    #         for parallelization_nr, query in enumerate(
    #             state.query_list[:MAX_DR_PARALLEL_SEARCH]
    #         )
    #     ]

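For orientation, here is a minimal, self-contained sketch of the same branch -> act -> reducer fan-out pattern that the graph builder and branching_router above wire up (it assumes only that langgraph is installed; the State schema and the string answers are illustrative, not Onyx's real types). A conditional edge returns one Send per query, the parallel "act" results are merged by the state reducer, and the "reducer" node runs once as the fan-in.

import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    queries: list[str]
    # operator.add merges the answers written by parallel branches
    answers: Annotated[list[str], operator.add]


def branch(state: State) -> dict:
    return {}  # the fan-out itself happens in the conditional edge below


def fan_out(state: State) -> list[Send]:
    # one "act" invocation per query, executed in parallel
    return [Send("act", {"queries": [q], "answers": []}) for q in state["queries"]]


def act(state: State) -> dict:
    return {"answers": [f"answer to: {state['queries'][0]}"]}


def reducer(state: State) -> dict:
    return {}  # fan-in: runs once after all parallel "act" branches finish


builder = StateGraph(State)
builder.add_node("branch", branch)
builder.add_node("act", act)
builder.add_node("reducer", reducer)
builder.add_edge(START, "branch")
builder.add_conditional_edges("branch", fan_out)
builder.add_edge("act", "reducer")
builder.add_edge("reducer", END)

app = builder.compile()
print(app.invoke({"queries": ["q1", "q2"], "answers": []}))
# -> answers contains one entry per query (merge order may vary)
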
@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def custom_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """
    # def custom_tool_branch(
    #     state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
    # ) -> LoggerUpdate:
    #     """
    #     LangGraph node to perform a generic tool call as part of the DR process.
    #     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
    # logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
    # return LoggerUpdate(
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="custom_tool",
    #             node_name="branching",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

@@ -1,169 +1,164 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
# from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
# from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
# from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def custom_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """
    # def custom_tool_act(
    #     state: BranchInput,
    #     config: RunnableConfig,
    #     writer: StreamWriter = lambda _: None,
    # ) -> BranchUpdate:
    #     """
    #     LangGraph node to perform a generic tool call as part of the DR process.
    #     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr
    # parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # if not state.available_tools:
    #     raise ValueError("available_tools is not set")

    custom_tool_info = state.available_tools[state.tools_used[-1]]
    custom_tool_name = custom_tool_info.name
    custom_tool = cast(CustomTool, custom_tool_info.tool_object)
    # custom_tool_info = state.available_tools[state.tools_used[-1]]
    # custom_tool_name = custom_tool_info.name
    # custom_tool = cast(CustomTool, custom_tool_info.tool_object)

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    # branch_query = state.branch_question
    # if not branch_query:
    #     raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    # graph_config = cast(GraphConfig, config["metadata"]["config"])
    # base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # logger.debug(
    #     f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    # )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=custom_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[custom_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_LONG,
        )
    # # get tool call args
    # tool_args: dict | None = None
    # if graph_config.tooling.using_tool_calling_llm:
    #     # get tool call args from tool-calling LLM
    #     tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
    #         query=branch_query,
    #         base_question=base_question,
    #         tool_description=custom_tool_info.description,
    #     )
    #     tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
    #         tool_use_prompt,
    #         tools=[custom_tool.tool_definition()],
    #         tool_choice="required",
    #         timeout_override=TF_DR_TIMEOUT_LONG,
    #     )

        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")
    # # make sure we got a tool call
    # if (
    #     isinstance(tool_calling_msg, AIMessage)
    #     and len(tool_calling_msg.tool_calls) == 1
    # ):
    #     tool_args = tool_calling_msg.tool_calls[0]["args"]
    # else:
    #     logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = custom_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )
    # if tool_args is None:
    #     raise ValueError(
    #         "Failed to obtain tool arguments from LLM - tool calling is required"
    #     )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    # # run the tool
    # response_summary: CustomToolCallSummary | None = None
    # for tool_response in custom_tool.run(**tool_args):
    #     if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
    #         response_summary = cast(CustomToolCallSummary, tool_response.response)
    #         break

    # run the tool
    response_summary: CustomToolCallSummary | None = None
    for tool_response in custom_tool.run(**tool_args):
        if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
            response_summary = cast(CustomToolCallSummary, tool_response.response)
            break
    # if not response_summary:
    #     raise ValueError("Custom tool did not return a valid response summary")

    if not response_summary:
        raise ValueError("Custom tool did not return a valid response summary")
    # # summarise tool result
    # if not response_summary.response_type:
    #     raise ValueError("Response type is not returned.")

    # summarise tool result
    if not response_summary.response_type:
        raise ValueError("Response type is not returned.")
    # if response_summary.response_type == "json":
    #     tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    # elif response_summary.response_type in {"image", "csv"}:
    #     tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    # else:
    #     tool_result_str = str(response_summary.tool_result)

    if response_summary.response_type == "json":
        tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
    elif response_summary.response_type in {"image", "csv"}:
        tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
    else:
        tool_result_str = str(response_summary.tool_result)
    # tool_str = (
    #     f"Tool used: {custom_tool_name}\n"
    #     f"Description: {custom_tool_info.description}\n"
    #     f"Result: {tool_result_str}"
    # )

    tool_str = (
        f"Tool used: {custom_tool_name}\n"
        f"Description: {custom_tool_info.description}\n"
        f"Result: {tool_result_str}"
    )
    # tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
    #     query=branch_query, base_question=base_question, tool_response=tool_str
    # )
    # answer_string = str(
    #     graph_config.tooling.primary_llm.invoke(
    #         tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
    #     ).content
    # ).strip()

    tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
        query=branch_query, base_question=base_question, tool_response=tool_str
    )
    answer_string = str(
        graph_config.tooling.primary_llm.invoke_langchain(
            tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
        ).content
    ).strip()
    # tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
    #     query=branch_query, base_question=base_question, tool_response=tool_str
    # )
    # answer_string = str(
    #     graph_config.tooling.primary_llm.invoke_langchain(
    #         tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
    #     ).content
    # ).strip()

    # get file_ids:
    file_ids = None
    if response_summary.response_type in {"image", "csv"} and hasattr(
        response_summary.tool_result, "file_ids"
    ):
        file_ids = response_summary.tool_result.file_ids
    # logger.debug(
    #     f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    # )

    logger.debug(
        f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=custom_tool_name,
                tool_id=custom_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning="",
                additional_data=None,
                response_type=response_summary.response_type,
                data=response_summary.tool_result,
                file_ids=file_ids,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="tool_calling",
                node_start_time=node_start_time,
            )
        ],
    )
    # return BranchUpdate(
    #     branch_iteration_responses=[
    #         IterationAnswer(
    #             tool=custom_tool_name,
    #             tool_id=custom_tool_info.tool_id,
    #             iteration_nr=iteration_nr,
    #             parallelization_nr=parallelization_nr,
    #             question=branch_query,
    #             answer=answer_string,
    #             claims=[],
    #             cited_documents={},
    #             reasoning="",
    #             additional_data=None,
    #             response_type=response_summary.response_type,
    #             data=response_summary.tool_result,
    #             file_ids=file_ids,
    #         )
    #     ],
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="custom_tool",
    #             node_name="tool_calling",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

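The "make sure we got a tool call" guard above hinges on langchain_core's AIMessage.tool_calls. A small self-contained sketch of that check, with the message constructed by hand purely for illustration (a real tool-calling LLM would produce it):

from langchain_core.messages import AIMessage

# a hand-built stand-in for what a tool-calling LLM would return
msg = AIMessage(
    content="",
    tool_calls=[{"name": "my_tool", "args": {"query": "status"}, "id": "call-1"}],
)

tool_args: dict | None = None
if isinstance(msg, AIMessage) and len(msg.tool_calls) == 1:
    # exactly one call requested, so its args can be forwarded to the tool
    tool_args = msg.tool_calls[0]["args"]

print(tool_args)  # -> {'query': 'status'}
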
@@ -1,82 +1,82 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
# from onyx.server.query_and_chat.streaming_models import CustomToolDelta
# from onyx.server.query_and_chat.streaming_models import CustomToolStart
# from onyx.server.query_and_chat.streaming_models import SectionEnd
# from onyx.utils.logger import setup_logger


logger = setup_logger()
# logger = setup_logger()


def custom_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """
    # def custom_tool_reducer(
    #     state: SubAgentMainState,
    #     config: RunnableConfig,
    #     writer: StreamWriter = lambda _: None,
    # ) -> SubAgentUpdate:
    #     """
    #     LangGraph node to perform a generic tool call as part of the DR process.
    #     """

    node_start_time = datetime.now()
    # node_start_time = datetime.now()

    current_step_nr = state.current_step_nr
    # current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    # branch_updates = state.branch_iteration_responses
    # current_iteration = state.iteration_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    # new_updates = [
    #     update for update in branch_updates if update.iteration_nr == current_iteration
    # ]

    for new_update in new_updates:
    # for new_update in new_updates:

        if not new_update.response_type:
            raise ValueError("Response type is not returned.")
        # if not new_update.response_type:
        #     raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )
        # write_custom_event(
        #     current_step_nr,
        #     CustomToolStart(
        #         tool_name=new_update.tool,
        #     ),
        #     writer,
        # )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=new_update.file_ids,
            ),
            writer,
        )
        # write_custom_event(
        #     current_step_nr,
        #     CustomToolDelta(
        #         tool_name=new_update.tool,
        #         response_type=new_update.response_type,
        #         data=new_update.data,
        #         file_ids=new_update.file_ids,
        #     ),
        #     writer,
        # )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )
        # write_custom_event(
        #     current_step_nr,
        #     SectionEnd(),
        #     writer,
        # )

        current_step_nr += 1
        # current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
    # return SubAgentUpdate(
    #     iteration_responses=new_updates,
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="custom_tool",
    #             node_name="consolidation",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

@@ -1,28 +1,28 @@
from collections.abc import Hashable
# from collections.abc import Hashable

from langgraph.types import Send
# from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
    SubAgentInput,
)
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import (
#     SubAgentInput,
# )


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1]  # no parallel call for now
        )
    ]
    # def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    #     return [
    #         Send(
    #             "act",
    #             BranchInput(
    #                 iteration_nr=state.iteration_nr,
    #                 parallelization_nr=parallelization_nr,
    #                 branch_question=query,
    #                 context="",
    #                 active_source_types=state.active_source_types,
    #                 tools_used=state.tools_used,
    #                 available_tools=state.available_tools,
    #             ),
    #         )
    #         for parallelization_nr, query in enumerate(
    #             state.query_list[:1]  # no parallel call for now
    #         )
    #     ]

@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
# from langgraph.graph import END
# from langgraph.graph import START
# from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
    custom_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
    custom_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
    custom_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
#     custom_tool_branch,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
#     custom_tool_act,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
#     custom_tool_reducer,
# )
# from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
#     branching_router,
# )
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
# from onyx.utils.logger import setup_logger


logger = setup_logger()
# logger = setup_logger()


def dr_custom_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Generic Tool Sub-Agent
    """
    # def dr_custom_tool_graph_builder() -> StateGraph:
    #     """
    #     LangGraph graph builder for Generic Tool Sub-Agent
    #     """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
    # graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###
    # ### Add nodes ###

    graph.add_node("branch", custom_tool_branch)
    # graph.add_node("branch", custom_tool_branch)

    graph.add_node("act", custom_tool_act)
    # graph.add_node("act", custom_tool_act)

    graph.add_node("reducer", custom_tool_reducer)
    # graph.add_node("reducer", custom_tool_reducer)

    ### Add edges ###
    # ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")
    # graph.add_edge(start_key=START, end_key="branch")

    graph.add_conditional_edges("branch", branching_router)
    # graph.add_conditional_edges("branch", branching_router)

    graph.add_edge(start_key="act", end_key="reducer")
    # graph.add_edge(start_key="act", end_key="reducer")

    graph.add_edge(start_key="reducer", end_key=END)
    # graph.add_edge(start_key="reducer", end_key=END)

    return graph
    # return graph

@@ -1,36 +1,36 @@
from datetime import datetime
# from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.states import LoggerUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def generic_internal_tool_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """
    # def generic_internal_tool_branch(
    #     state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
    # ) -> LoggerUpdate:
    #     """
    #     LangGraph node to perform a generic tool call as part of the DR process.
    #     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr

    logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
    # logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="generic_internal_tool",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
    # return LoggerUpdate(
    #     log_messages=[
    #         get_langgraph_node_log_string(
    #             graph_component="generic_internal_tool",
    #             node_name="branching",
    #             node_start_time=node_start_time,
    #         )
    #     ],
    # )

@@ -1,149 +1,147 @@
import json
from datetime import datetime
from typing import cast
# import json
# from datetime import datetime
# from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
# from langchain_core.messages import AIMessage
# from langchain_core.runnables import RunnableConfig
# from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
from onyx.utils.logger import setup_logger
# from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
# from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
# from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
# from onyx.agents.agent_search.models import GraphConfig
# from onyx.agents.agent_search.shared_graph_utils.utils import (
#     get_langgraph_node_log_string,
# )
# from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
# from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
# from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
# from onyx.utils.logger import setup_logger

logger = setup_logger()
# logger = setup_logger()


def generic_internal_tool_act(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform a generic tool call as part of the DR process.
    """
    # def generic_internal_tool_act(
    #     state: BranchInput,
    #     config: RunnableConfig,
    #     writer: StreamWriter = lambda _: None,
    # ) -> BranchUpdate:
    #     """
    #     LangGraph node to perform a generic tool call as part of the DR process.
    #     """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    # node_start_time = datetime.now()
    # iteration_nr = state.iteration_nr
    # parallelization_nr = state.parallelization_nr

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    # if not state.available_tools:
    #     raise ValueError("available_tools is not set")

    generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    generic_internal_tool_name = generic_internal_tool_info.llm_path
    generic_internal_tool = generic_internal_tool_info.tool_object
    # generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
    # generic_internal_tool_name = generic_internal_tool_info.llm_path
    # generic_internal_tool = generic_internal_tool_info.tool_object

    if generic_internal_tool is None:
        raise ValueError("generic_internal_tool is not set")
    # if generic_internal_tool is None:
    #     raise ValueError("generic_internal_tool is not set")

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")
    # branch_query = state.branch_question
    # if not branch_query:
    #     raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    # graph_config = cast(GraphConfig, config["metadata"]["config"])
    # base_question = graph_config.inputs.prompt_builder.raw_user_query

    logger.debug(
        f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )
    # logger.debug(
    #     f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    # )

    # get tool call args
    tool_args: dict | None = None
    if graph_config.tooling.using_tool_calling_llm:
        # get tool call args from tool-calling LLM
        tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
            query=branch_query,
            base_question=base_question,
            tool_description=generic_internal_tool_info.description,
        )
        tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
            tool_use_prompt,
            tools=[generic_internal_tool.tool_definition()],
            tool_choice="required",
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
    # # get tool call args
    # tool_args: dict | None = None
    # if graph_config.tooling.using_tool_calling_llm:
    #     # get tool call args from tool-calling LLM
    #     tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
    #         query=branch_query,
    #         base_question=base_question,
    #         tool_description=generic_internal_tool_info.description,
    #     )
    #     tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
    #         tool_use_prompt,
    #         tools=[generic_internal_tool.tool_definition()],
    #         tool_choice="required",
    #         timeout_override=TF_DR_TIMEOUT_SHORT,
    #     )

        # make sure we got a tool call
        if (
            isinstance(tool_calling_msg, AIMessage)
            and len(tool_calling_msg.tool_calls) == 1
        ):
            tool_args = tool_calling_msg.tool_calls[0]["args"]
        else:
            logger.warning("Tool-calling LLM did not emit a tool call")
    # # make sure we got a tool call
    # if (
    #     isinstance(tool_calling_msg, AIMessage)
    #     and len(tool_calling_msg.tool_calls) == 1
    # ):
    #     tool_args = tool_calling_msg.tool_calls[0]["args"]
    # else:
    #     logger.warning("Tool-calling LLM did not emit a tool call")

    if tool_args is None:
        # get tool call args from non-tool-calling LLM or for failed tool-calling LLM
        tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
            query=branch_query,
            history=[],
            llm=graph_config.tooling.primary_llm,
            force_run=True,
        )
    # if tool_args is None:
    #     raise ValueError(
    #         "Failed to obtain tool arguments from LLM - tool calling is required"
    #     )

    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    # # run the tool
    # tool_responses = list(generic_internal_tool.run(**tool_args))
    # final_data = generic_internal_tool.final_result(*tool_responses)
    # tool_result_str = json.dumps(final_data, ensure_ascii=False)
|
||||
|
||||
# run the tool
|
||||
tool_responses = list(generic_internal_tool.run(**tool_args))
|
||||
final_data = generic_internal_tool.final_result(*tool_responses)
|
||||
tool_result_str = json.dumps(final_data, ensure_ascii=False)
|
||||
# tool_str = (
|
||||
# f"Tool used: {generic_internal_tool.display_name}\n"
|
||||
# f"Description: {generic_internal_tool_info.description}\n"
|
||||
# f"Result: {tool_result_str}"
|
||||
# )
|
||||
|
||||
tool_str = (
|
||||
f"Tool used: {generic_internal_tool.display_name}\n"
|
||||
f"Description: {generic_internal_tool_info.description}\n"
|
||||
f"Result: {tool_result_str}"
|
||||
)
|
||||
# if generic_internal_tool.display_name == "Okta Profile":
|
||||
# tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
# else:
|
||||
# tool_prompt = CUSTOM_TOOL_USE_PROMPT
|
||||
|
||||
if generic_internal_tool.display_name == "Okta Profile":
|
||||
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
else:
|
||||
tool_prompt = CUSTOM_TOOL_USE_PROMPT
|
||||
# tool_summary_prompt = tool_prompt.build(
|
||||
# query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
# )
|
||||
# answer_string = str(
|
||||
# graph_config.tooling.primary_llm.invoke(
|
||||
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
# ).content
|
||||
# ).strip()
|
||||
|
||||
tool_summary_prompt = tool_prompt.build(
|
||||
query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
# tool_summary_prompt = tool_prompt.build(
|
||||
# query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
# )
|
||||
# answer_string = str(
|
||||
# graph_config.tooling.primary_llm.invoke_langchain(
|
||||
# tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
# ).content
|
||||
# ).strip()
|
||||
|
||||
logger.debug(
|
||||
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=generic_internal_tool.llm_name,
|
||||
tool_id=generic_internal_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
response_type="text", # TODO: convert all response types to enums
|
||||
data=answer_string,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
# return BranchUpdate(
|
||||
# branch_iteration_responses=[
|
||||
# IterationAnswer(
|
||||
# tool=generic_internal_tool.name,
|
||||
# tool_id=generic_internal_tool_info.tool_id,
|
||||
# iteration_nr=iteration_nr,
|
||||
# parallelization_nr=parallelization_nr,
|
||||
# question=branch_query,
|
||||
# answer=answer_string,
|
||||
# claims=[],
|
||||
# cited_documents={},
|
||||
# reasoning="",
|
||||
# additional_data=None,
|
||||
# response_type="text", # TODO: convert all response types to enums
|
||||
# data=answer_string,
|
||||
# )
|
||||
# ],
|
||||
# log_messages=[
|
||||
# get_langgraph_node_log_string(
|
||||
# graph_component="custom_tool",
|
||||
# node_name="tool_calling",
|
||||
# node_start_time=node_start_time,
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
|
||||
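The argument-extraction logic above follows a two-stage pattern: ask the tool-calling LLM for a native tool call first, and fall back to prompted argument extraction if it does not produce exactly one call. A minimal self-contained sketch of that pattern with a stubbed LLM; all names below are illustrative stand-ins, not Onyx or LangChain APIs:

# Minimal sketch of the two-stage argument-extraction pattern used above.
# StubLLM and its methods are illustrative stand-ins, not Onyx or LangChain APIs.
from typing import Any


class StubLLM:
    def invoke_with_tools(self, prompt: str) -> list[dict[str, Any]]:
        # A real tool-calling LLM may return zero, one, or many tool calls.
        return [{"args": {"query": prompt}}]

    def extract_args_via_prompt(self, prompt: str) -> dict[str, Any] | None:
        # Fallback: ask a plain-text LLM to emit the arguments directly.
        return {"query": prompt}


def get_tool_args(llm: StubLLM, prompt: str, use_tool_calling: bool) -> dict[str, Any]:
    tool_args: dict[str, Any] | None = None
    if use_tool_calling:
        calls = llm.invoke_with_tools(prompt)
        if len(calls) == 1:  # accept only an unambiguous single tool call
            tool_args = calls[0]["args"]
    if tool_args is None:  # mirrors the get_args_for_non_tool_calling_llm fallback
        tool_args = llm.extract_args_via_prompt(prompt)
    if tool_args is None:
        raise ValueError("Failed to obtain tool arguments from LLM")
    return tool_args


print(get_tool_args(StubLLM(), "look up the user's Okta profile", use_tool_calling=True))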
@@ -1,82 +1,82 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generic_internal_tool_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to consolidate the results of generic tool calls as part of the DR process.
    """

    node_start_time = datetime.now()

    current_step_nr = state.current_step_nr

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]

    for new_update in new_updates:
        if not new_update.response_type:
            raise ValueError("Response type is not returned.")

        write_custom_event(
            current_step_nr,
            CustomToolStart(
                tool_name=new_update.tool,
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            CustomToolDelta(
                tool_name=new_update.tool,
                response_type=new_update.response_type,
                data=new_update.data,
                file_ids=[],
            ),
            writer,
        )

        write_custom_event(
            current_step_nr,
            SectionEnd(),
            writer,
        )

        current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="custom_tool",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
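The reducer streams each iteration response to the frontend as an ordered CustomToolStart -> CustomToolDelta -> SectionEnd triple, advancing the step counter per response. A sketch of that event ordering, with plain dicts standing in for the streaming models (names illustrative):

# Sketch of the per-response event ordering; plain dicts stand in for
# CustomToolStart / CustomToolDelta / SectionEnd.
def emit_tool_events(responses, write_event, first_step_nr=0):
    step_nr = first_step_nr
    for resp in responses:
        write_event(step_nr, {"type": "start", "tool": resp["tool"]})
        write_event(step_nr, {"type": "delta", "tool": resp["tool"], "data": resp["data"]})
        write_event(step_nr, {"type": "section_end"})
        step_nr += 1  # each response closes its own frontend section
    return step_nr


events = []
emit_tool_events(
    [{"tool": "Okta Profile", "data": "profile data"}],
    lambda nr, payload: events.append((nr, payload)),
)
print(events)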
@@ -1,28 +1,28 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import (
    SubAgentInput,
)


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:1]  # no parallel call for now
        )
    ]
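The router above uses LangGraph's Send-based fan-out: a conditional edge that returns Send objects runs the target node once per payload (here capped to a single query). A minimal self-contained sketch of the same map-reduce pattern, assuming only the public langgraph API:

# Self-contained LangGraph fan-out sketch (illustrative, not Onyx code).
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    queries: list[str]
    answers: Annotated[list[str], operator.add]  # fan-in: concatenate branch results


def fan_out(state: State) -> list[Send]:
    # One Send per query: LangGraph schedules "act" once per payload.
    return [Send("act", {"queries": [q], "answers": []}) for q in state["queries"]]


def act(state: State) -> dict:
    return {"answers": [f"handled: {state['queries'][0]}"]}


graph = StateGraph(State)
graph.add_node("act", act)
graph.add_conditional_edges(START, fan_out)
graph.add_edge("act", END)
app = graph.compile()
print(app.invoke({"queries": ["a", "b"], "answers": []}))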
@@ -1,50 +1,50 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
    generic_internal_tool_branch,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
    generic_internal_tool_act,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
    generic_internal_tool_reducer,
)
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
    branching_router,
)
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def dr_generic_internal_tool_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the Generic Tool Sub-Agent.
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    graph.add_node("branch", generic_internal_tool_branch)
    graph.add_node("act", generic_internal_tool_act)
    graph.add_node("reducer", generic_internal_tool_reducer)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="branch")
    graph.add_conditional_edges("branch", branching_router)
    graph.add_edge(start_key="act", end_key="reducer")
    graph.add_edge(start_key="reducer", end_key=END)

    return graph
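For orientation, a sketch of typical usage of a builder like this; compile() and get_graph() are standard LangGraph StateGraph methods, and the call below assumes the builder function above is in scope:

# Usage sketch (assumes dr_generic_internal_tool_graph_builder above is in scope).
builder = dr_generic_internal_tool_graph_builder()
compiled = builder.compile()  # standard LangGraph compile step
# Wiring: START -> branch -> (Send fan-out via branching_router) -> act -> reducer -> END
print(compiled.get_graph())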
@@ -1,45 +1,45 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.utils.logger import setup_logger

logger = setup_logger()


def image_generation_branch(
    state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> LoggerUpdate:
    """
    LangGraph node to perform an image generation as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr

    logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")

    # tell frontend that we are starting the image generation tool
    write_custom_event(
        state.current_step_nr,
        ImageGenerationToolStart(),
        writer,
    )

    return LoggerUpdate(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="branching",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,189 +1,187 @@
import json
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.tools.tool_implementations.images.image_generation_tool import (
    IMAGE_GENERATION_HEARTBEAT_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    IMAGE_GENERATION_RESPONSE_ID,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger

logger = setup_logger()


def image_generation(
    state: BranchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to perform image generation as part of the DR process.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    state.assistant_system_prompt
    state.assistant_task_prompt

    branch_query = state.branch_question
    if not branch_query:
        raise ValueError("branch_query is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.prompt_builder.raw_user_query
    graph_config.behavior.research_type

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    image_tool_info = state.available_tools[state.tools_used[-1]]
    image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)

    image_prompt = branch_query
    requested_shape: ImageShape | None = None

    try:
        parsed_query = json.loads(branch_query)
    except json.JSONDecodeError:
        parsed_query = None

    if isinstance(parsed_query, dict):
        prompt_from_llm = parsed_query.get("prompt")
        if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
            image_prompt = prompt_from_llm.strip()

        raw_shape = parsed_query.get("shape")
        if isinstance(raw_shape, str):
            try:
                requested_shape = ImageShape(raw_shape)
            except ValueError:
                logger.warning(
                    "Received unsupported image shape '%s' from LLM. Falling back to square.",
                    raw_shape,
                )

    logger.debug(
        f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Generate images using the image generation tool
    image_generation_responses: list[ImageGenerationResponse] = []

    if requested_shape is not None:
        tool_iterator = image_tool.run(
            prompt=image_prompt,
            shape=requested_shape.value,
        )
    else:
        tool_iterator = image_tool.run(prompt=image_prompt)

    for tool_response in tool_iterator:
        if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
            # Stream heartbeat to frontend
            write_custom_event(
                state.current_step_nr,
                ImageGenerationToolHeartbeat(),
                writer,
            )
        elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
            response = cast(list[ImageGenerationResponse], tool_response.response)
            image_generation_responses = response
            break

    # save images to file store
    file_ids = save_files(
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [
        GeneratedImage(
            file_id=file_id,
            url=build_frontend_file_url(file_id),
            revised_prompt=img.revised_prompt,
            shape=(requested_shape or ImageShape.SQUARE).value,
        )
        for file_id, img in zip(file_ids, image_generation_responses)
    ]

    logger.debug(
        f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
    )

    # Create answer string describing the generated images
    if final_generated_images:
        image_descriptions = []
        for i, img in enumerate(final_generated_images, 1):
            if img.shape and img.shape != ImageShape.SQUARE.value:
                image_descriptions.append(
                    f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
                )
            else:
                image_descriptions.append(f"Image {i}: {img.revised_prompt}")

        answer_string = (
            f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
            + "\n".join(image_descriptions)
        )
        if requested_shape:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
            )
        else:
            reasoning = (
                "Used image generation tool to create "
                f"{len(final_generated_images)} image(s) based on the user's request."
            )
    else:
        answer_string = f"Failed to generate images for request: {image_prompt}"
        reasoning = "Image generation tool did not return any results."

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=image_tool_info.llm_path,
                tool_id=image_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=branch_query,
                answer=answer_string,
                claims=[],
                cited_documents={},
                reasoning=reasoning,
                generated_images=final_generated_images,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="generating",
                node_start_time=node_start_time,
            )
        ],
    )
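The prompt/shape handling above is a parse-with-fallback: the branch question may arrive as plain text or as LLM-produced JSON, and an unsupported shape degrades to the square default rather than failing. A self-contained sketch of that parsing; ImageShapeSketch and its non-square values are stand-ins for the real ImageShape enum:

# Sketch of the parse-with-fallback for LLM-produced image requests.
# ImageShapeSketch is a stand-in; the real ImageShape lives in the Onyx codebase.
import json
from enum import Enum


class ImageShapeSketch(str, Enum):
    SQUARE = "square"
    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"


def parse_branch_query(branch_query: str) -> tuple[str, ImageShapeSketch | None]:
    prompt: str = branch_query  # fall back to the raw string if JSON parsing fails
    shape: ImageShapeSketch | None = None
    try:
        parsed = json.loads(branch_query)
    except json.JSONDecodeError:
        return prompt, shape
    if isinstance(parsed, dict):
        p = parsed.get("prompt")
        if isinstance(p, str) and p.strip():
            prompt = p.strip()
        raw_shape = parsed.get("shape")
        if isinstance(raw_shape, str):
            try:
                shape = ImageShapeSketch(raw_shape)
            except ValueError:
                pass  # unsupported shape: leave None so the caller defaults to square
    return prompt, shape


print(parse_branch_query('{"prompt": "a red fox", "shape": "landscape"}'))
print(parse_branch_query("a plain text prompt"))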
@@ -1,71 +1,71 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger

logger = setup_logger()


def is_reducer(
    state: SubAgentMainState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> SubAgentUpdate:
    """
    LangGraph node to consolidate image generation results as part of the DR process.
    """

    node_start_time = datetime.now()

    branch_updates = state.branch_iteration_responses
    current_iteration = state.iteration_nr
    current_step_nr = state.current_step_nr

    new_updates = [
        update for update in branch_updates if update.iteration_nr == current_iteration
    ]
    generated_images: list[GeneratedImage] = []
    for update in new_updates:
        if update.generated_images:
            generated_images.extend(update.generated_images)

    # Write the results to the stream
    write_custom_event(
        current_step_nr,
        ImageGenerationToolDelta(
            images=generated_images,
        ),
        writer,
    )

    write_custom_event(
        current_step_nr,
        SectionEnd(),
        writer,
    )

    current_step_nr += 1

    return SubAgentUpdate(
        iteration_responses=new_updates,
        current_step_nr=current_step_nr,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="image_generation",
                node_name="consolidation",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,29 +1,29 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput


def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    return [
        Send(
            "act",
            BranchInput(
                iteration_nr=state.iteration_nr,
                parallelization_nr=parallelization_nr,
                branch_question=query,
                context="",
                active_source_types=state.active_source_types,
                tools_used=state.tools_used,
                available_tools=state.available_tools,
                assistant_system_prompt=state.assistant_system_prompt,
                assistant_task_prompt=state.assistant_task_prompt,
            ),
        )
        for parallelization_nr, query in enumerate(
            state.query_list[:MAX_DR_PARALLEL_SEARCH]
        )
    ]
Some files were not shown because too many files have changed in this diff.