Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: v2.0.2-clo ... nightly-la (129 commits)
Commit SHAs (the compare view's Author and Date columns were not captured by the mirror):

bffca81477, 561b487102, cc9b14c99b, de674a19e0, 79114bf92c, b5dccd96b3, a55cc5a537, cdf3cc444b, cd3941f4b7, 0182743619, 0e2f596aa2, 0be45676b7, 30a3470001, d52fa83afa, 9eb5643cc3, afe34218b8, 4776947dfa, c4bc25f540, b77078b339, 88b28a303b, 59d7d3905a, a48fe7550a, c25a99955c, ac509f865a, 5819389ae8, eae5774cdc, 8fed0a8138, c04196941d, 19461955ed, cb3152ff5c, cf187e8f58, deaa3df42f, d6e98bfbc8, ff58ad0b87, eb7cb02cc0, 7876d8da1b, 8a6f83115e, b7f81aed10, a415a997cf, 7781afd74e, d0a4f4ce66, ba00de8904, 91f21bb22b, 491f3127c5, 0987fb852b, 5f68141335, b5793ee522, 238c244fec, c103a878b7, 03deb064cc, 09062195b4, dc57a5451c, 781f60a5ab, 423961fefb, 324b6ceeef, d9e14bf5da, eb2cb1bb25, 0de9f47694, 2757f3936c, 8ba61e9123, c10d7fbc32, b6ed217781, 7d20f73f71, 2b306255f9, e149d08d47, e98ddb9fe6, b9a5297694, 4666312df2, d4e524cd83, a719228034, 2fe8b5e33a, af243b0ef5, c96ac04619, e2f2950fee, 8b84c59d29, b718a276cf, 700511720f, 6bd1719156, c8bfe9e0a1, 037bc04740, c3704d47df, 397a153ff6, 870c432ccf, c4a81a590f, 017c095eed, ee37d21aa4, e492d88b2d, 3512fdcd9d, 3550795cab, b26306d678, 85140b4ba6, c241f79f97, 9808dec6b7, 632c74af6d, 79073d878c, 620df88c51, 717f05975d, d2176342c1, bb198b05e1, 085013d8c3, e46f632570, bbb4b9eda3, 12b7c7d4dd, 464967340b, a2308c2f45, 2ee9f79f71, c3904b7c96, 5009dcf911, c7b4a0fad9, 60a402fcab, c9bb078a37, c36c2a6c8d, f9e2f9cbb4, 0b7c808480, 0a6ff30ee4, dc036eb452, ee950b9cbd, dd71765849, dc6b97f1b1, d960c23b6a, d9c753ba92, 60234dd6da, f88ef2e9ff, 6b479a01ea, 248fe416e1, cbea4bb75c, 4a147a48dc, a77025cd46, d10914ccc6
.github/actions/prepare-build/action.yml (vendored, new file, 50 lines)

@@ -0,0 +1,50 @@
```yaml
name: "Prepare Build (OpenAPI generation)"
description: "Sets up Python with uv, installs deps, generates OpenAPI schema and Python client, uploads artifact"
runs:
  using: "composite"
  steps:
    - name: Checkout code
      uses: actions/checkout@v4

    - name: Setup uv
      uses: astral-sh/setup-uv@v3

    - name: Setup Python
      uses: actions/setup-python@v5
      with:
        python-version: "3.11"

    - name: Install Python dependencies with uv
      shell: bash
      run: |
        uv pip install --system \
          -r backend/requirements/default.txt \
          -r backend/requirements/dev.txt

    - name: Generate OpenAPI schema
      shell: bash
      working-directory: backend
      env:
        PYTHONPATH: "."
      run: |
        python scripts/onyx_openapi_schema.py --filename generated/openapi.json

    - name: Generate OpenAPI Python client
      shell: bash
      run: |
        docker run --rm \
          -v "${{ github.workspace }}/backend/generated:/local" \
          openapitools/openapi-generator-cli generate \
          -i /local/openapi.json \
          -g python \
          -o /local/onyx_openapi_client \
          --package-name onyx_openapi_client \
          --skip-validate-spec \
          --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"

    - name: Upload OpenAPI artifacts
      uses: actions/upload-artifact@v4
      with:
        name: openapi-artifacts
        path: backend/generated/
```
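The generation steps this composite action performs can also be reproduced locally; a minimal sketch using the same commands as the action above, run from the repository root:

```bash
# Generate the OpenAPI schema (same script the action calls)
cd backend
PYTHONPATH=. python scripts/onyx_openapi_schema.py --filename generated/openapi.json

# Generate the Python client via the openapi-generator container
docker run --rm \
  -v "$PWD/generated:/local" \
  openapitools/openapi-generator-cli generate \
  -i /local/openapi.json \
  -g python \
  -o /local/onyx_openapi_client \
  --package-name onyx_openapi_client \
  --skip-validate-spec \
  --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
```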
```diff
@@ -42,6 +42,11 @@ jobs:
           else
             echo "is_stable=false" >> $GITHUB_OUTPUT
           fi
+          if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
+            echo "is_beta=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_beta=false" >> $GITHUB_OUTPUT
+          fi

       - name: Checkout code
         uses: actions/checkout@v4
@@ -57,6 +62,7 @@
             type=raw,value=${{ github.ref_name }}
             type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
             type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
+            type=raw,value=${{ steps.check_version.outputs.is_beta == 'true' && 'beta' || '' }}

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -118,6 +124,11 @@
           else
             echo "is_stable=false" >> $GITHUB_OUTPUT
           fi
+          if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
+            echo "is_beta=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_beta=false" >> $GITHUB_OUTPUT
+          fi

       - name: Download digests
         uses: actions/download-artifact@v4
@@ -140,6 +151,7 @@
             type=raw,value=${{ github.ref_name }}
             type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
             type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
+            type=raw,value=${{ steps.check_version.outputs.is_beta == 'true' && 'beta' || '' }}

       - name: Login to Docker Hub
         uses: docker/login-action@v3
```
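The beta check added in these hunks is a plain bash regex; a quick local sketch of how it classifies ref names (sample tags are illustrative):

```bash
# Mirrors the workflow's is_beta check: vX.Y.Z-beta.N, excluding "cloud" tags
for ref in v2.0.2 v2.1.0-beta.3 v2.0.2-cloud.1; do
  if [[ "$ref" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "$ref" != *"cloud"* ]]; then
    echo "$ref -> is_beta=true"
  else
    echo "$ref -> is_beta=false"
  fi
done
```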
```diff
@@ -88,7 +88,7 @@
           push: true
           tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
           build-args: |
-            DANSWER_VERSION=${{ github.ref_name }}
+            ONYX_VERSION=${{ github.ref_name }}
           outputs: type=registry
           provenance: false
           cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
@@ -134,7 +134,7 @@
           push: true
           tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
           build-args: |
-            DANSWER_VERSION=${{ github.ref_name }}
+            ONYX_VERSION=${{ github.ref_name }}
           outputs: type=registry
           provenance: false
           cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
@@ -153,6 +153,11 @@
           else
             echo "is_stable=false" >> $GITHUB_OUTPUT
           fi
+          if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
+            echo "is_beta=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_beta=false" >> $GITHUB_OUTPUT
+          fi

       - name: Login to Docker Hub
         uses: docker/login-action@v3
@@ -176,6 +181,11 @@
               ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
               ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
           fi
+          if [[ "${{ steps.check_version.outputs.is_beta }}" == "true" ]]; then
+            docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:beta \
+              ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
+              ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
+          fi

       - name: Run Trivy vulnerability scanner
         uses: nick-fields/retry@v3
```
```diff
@@ -56,6 +56,11 @@ jobs:
           else
             echo "is_stable=false" >> $GITHUB_OUTPUT
           fi
+          if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]]; then
+            echo "is_beta=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_beta=false" >> $GITHUB_OUTPUT
+          fi

       - name: Checkout
         uses: actions/checkout@v4
@@ -71,6 +76,7 @@
             type=raw,value=${{ github.ref_name }}
             type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
             type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
+            type=raw,value=${{ steps.check_version.outputs.is_beta == 'true' && 'beta' || '' }}

       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
@@ -128,6 +134,11 @@
           else
             echo "is_stable=false" >> $GITHUB_OUTPUT
           fi
+          if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]]; then
+            echo "is_beta=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_beta=false" >> $GITHUB_OUTPUT
+          fi

       - name: Download digests
         uses: actions/download-artifact@v4
@@ -150,6 +161,7 @@
             type=raw,value=${{ github.ref_name }}
             type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
             type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
+            type=raw,value=${{ steps.check_version.outputs.is_beta == 'true' && 'beta' || '' }}

       - name: Login to Docker Hub
         uses: docker/login-action@v3
```
.github/workflows/docker-tag-beta.yml (vendored, new file, 41 lines)

@@ -0,0 +1,41 @@
```yaml
# This workflow is set up to be manually triggered via the GitHub Action tab.
# Given a version, it will tag those backend and webserver images as "beta".

name: Tag Beta Version

on:
  workflow_dispatch:
    inputs:
      version:
        description: "The version (ie v1.0.0-beta.0) to tag as beta"
        required: true

jobs:
  tag:
    # See https://runs-on.com/runners/linux/
    # use a lower powered instance since this just does i/o to docker hub
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Enable Docker CLI experimental features
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push API Server Image
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push Model Server Image
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
```
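Since the workflow is workflow_dispatch-only, it can also be kicked off from the command line; a sketch assuming the GitHub CLI is installed and authenticated against the repo (the version value is illustrative):

```bash
# Trigger the manual beta-tagging workflow for a given version
gh workflow run "Tag Beta Version" -f version=v1.0.0-beta.0
```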
.github/workflows/pr-helm-chart-testing.yml (vendored, 4 lines changed)

```diff
@@ -19,9 +19,9 @@ jobs:
           fetch-depth: 0

       - name: Set up Helm
-        uses: azure/setup-helm@v4.2.0
+        uses: azure/setup-helm@v4.3.1
         with:
-          version: v3.17.0
+          version: v3.19.0

       - name: Set up chart-testing
         uses: helm/chart-testing-action@v2.7.0
```
.github/workflows/pr-integration-tests.yml (vendored, 62 lines changed)

```diff
@@ -31,6 +31,7 @@ env:
   PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
   PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
   PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
+  EXA_API_KEY: ${{ secrets.EXA_API_KEY }}

 jobs:
   discover-test-dirs:
@@ -67,46 +68,8 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.11"
-          cache: "pip"
-          cache-dependency-path: |
-            backend/requirements/default.txt
-            backend/requirements/dev.txt
-
-      - name: Install Python dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-
-      - name: Generate OpenAPI schema
-        working-directory: ./backend
-        env:
-          PYTHONPATH: "."
-        run: |
-          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
-
-      - name: Generate OpenAPI Python client
-        working-directory: ./backend
-        run: |
-          docker run --rm \
-            -v "${{ github.workspace }}/backend/generated:/local" \
-            openapitools/openapi-generator-cli generate \
-            -i /local/openapi.json \
-            -g python \
-            -o /local/onyx_openapi_client \
-            --package-name onyx_openapi_client \
-            --skip-validate-spec \
-            --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
-
-      - name: Upload OpenAPI artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: openapi-artifacts
-          path: backend/generated/
+      - name: Prepare build
+        uses: ./.github/actions/prepare-build

   build-backend-image:
     runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -135,6 +98,7 @@ jobs:
           outputs: type=registry
           no-cache: true

+
   build-model-server-image:
     runs-on: blacksmith-16vcpu-ubuntu-2404-arm
     steps:
@@ -161,7 +125,8 @@ jobs:
           push: true
           outputs: type=registry
           provenance: false
-          no-cache: true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-integration-image:
     needs: prepare-build
@@ -186,16 +151,11 @@ jobs:
       - name: Set up Docker Buildx
         uses: useblacksmith/setup-docker-builder@v1

-      - name: Build and push integration test Docker image
-        uses: useblacksmith/build-push-action@v2
-        with:
-          context: ./backend
-          file: ./backend/tests/integration/Dockerfile
-          platforms: linux/arm64
-          tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
-          push: true
-          outputs: type=registry
-          no-cache: true
+      - name: Build and push integration test image with Docker Bake
+        env:
+          REGISTRY: ${{ env.PRIVATE_REGISTRY }}
+          TAG: test-${{ github.run_id }}
+        run: cd backend && docker buildx bake --no-cache --push integration

   integration-tests:
     needs:
```
.github/workflows/pr-mit-integration-tests.yml (vendored, 61 lines changed)

```diff
@@ -64,46 +64,8 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.11"
-          cache: "pip"
-          cache-dependency-path: |
-            backend/requirements/default.txt
-            backend/requirements/dev.txt
-
-      - name: Install Python dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-
-      - name: Generate OpenAPI schema
-        working-directory: ./backend
-        env:
-          PYTHONPATH: "."
-        run: |
-          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
-
-      - name: Generate OpenAPI Python client
-        working-directory: ./backend
-        run: |
-          docker run --rm \
-            -v "${{ github.workspace }}/backend/generated:/local" \
-            openapitools/openapi-generator-cli generate \
-            -i /local/openapi.json \
-            -g python \
-            -o /local/onyx_openapi_client \
-            --package-name onyx_openapi_client \
-            --skip-validate-spec \
-            --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
-
-      - name: Upload OpenAPI artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: openapi-artifacts
-          path: backend/generated/
+      - name: Prepare build
+        uses: ./.github/actions/prepare-build

   build-backend-image:
     runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -132,6 +94,7 @@ jobs:
           outputs: type=registry
           no-cache: true

+
   build-model-server-image:
     runs-on: blacksmith-16vcpu-ubuntu-2404-arm
     steps:
@@ -158,7 +121,8 @@ jobs:
           push: true
           outputs: type=registry
           provenance: false
-          no-cache: true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-integration-image:
     needs: prepare-build
@@ -183,16 +147,11 @@ jobs:
       - name: Set up Docker Buildx
         uses: useblacksmith/setup-docker-builder@v1

-      - name: Build and push integration test Docker image
-        uses: useblacksmith/build-push-action@v2
-        with:
-          context: ./backend
-          file: ./backend/tests/integration/Dockerfile
-          platforms: linux/arm64
-          tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
-          push: true
-          outputs: type=registry
-          no-cache: true
+      - name: Build and push integration test image with Docker Bake
+        env:
+          REGISTRY: ${{ env.PRIVATE_REGISTRY }}
+          TAG: test-${{ github.run_id }}
+        run: cd backend && docker buildx bake --no-cache --push integration

   integration-tests-mit:
     needs:
```
.github/workflows/pr-playwright-tests.yml (vendored, 20 lines changed)

```diff
@@ -12,7 +12,7 @@ env:
   AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_ECR }}
   AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_ECR }}
   BUILDX_NO_DEFAULT_ATTESTATIONS: 1

   # Test Environment Variables
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
@@ -57,7 +57,7 @@ jobs:
           sbom: false
           push: true
           outputs: type=registry
-          # no-cache: true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-backend-image:
     runs-on: blacksmith-8vcpu-ubuntu-2404-arm
@@ -90,7 +90,7 @@ jobs:
           sbom: false
           push: true
           outputs: type=registry
-          # no-cache: true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   build-model-server-image:
     runs-on: blacksmith-8vcpu-ubuntu-2404-arm
@@ -123,7 +123,7 @@ jobs:
           sbom: false
           push: true
           outputs: type=registry
-          # no-cache: true
+          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

   playwright-tests:
     needs: [build-web-image, build-backend-image, build-model-server-image]
@@ -215,15 +215,15 @@ jobs:
          while true; do
            current_time=$(date +%s)
            elapsed_time=$((current_time - start_time))

            if [ $elapsed_time -ge $timeout ]; then
              echo "Timeout reached. Service did not become ready in 5 minutes."
              exit 1
            fi

            # Use curl with error handling to ignore specific exit code 56
            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")

            if [ "$response" = "200" ]; then
              echo "Service is ready!"
              break
@@ -232,7 +232,7 @@ jobs:
            else
              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
            fi

            sleep 5
          done
          echo "Finished waiting for service."
@@ -247,9 +247,9 @@ jobs:
       - uses: actions/upload-artifact@v4
         if: always()
         with:
-          # Includes test results and debug screenshots
+          # Includes test results and trace.zip files
           name: playwright-test-results-${{ github.run_id }}
-          path: ./web/test-results
+          path: ./web/test-results/
           retention-days: 30

       # save before stopping the containers so the logs can be captured
```
.github/workflows/sync_foss.yml (vendored, new file, 47 lines)

@@ -0,0 +1,47 @@
```yaml
name: Sync FOSS Repo

on:
  schedule:
    # Run daily at 3am PT (11am UTC during PST)
    - cron: '0 11 * * *'
  workflow_dispatch:

jobs:
  sync-foss:
    runs-on: ubuntu-latest
    permissions:
      contents: read
    steps:
      - name: Checkout main Onyx repo
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Install git-filter-repo
        run: |
          sudo apt-get update && sudo apt-get install -y git-filter-repo

      - name: Configure SSH for deploy key
        env:
          FOSS_REPO_DEPLOY_KEY: ${{ secrets.FOSS_REPO_DEPLOY_KEY }}
        run: |
          mkdir -p ~/.ssh
          echo "$FOSS_REPO_DEPLOY_KEY" > ~/.ssh/id_ed25519
          chmod 600 ~/.ssh/id_ed25519
          ssh-keyscan github.com >> ~/.ssh/known_hosts

      - name: Set Git config
        run: |
          git config --global user.name "onyx-bot"
          git config --global user.email "bot@onyx.app"

      - name: Build FOSS version
        run: bash backend/scripts/make_foss_repo.sh

      - name: Push to FOSS repo
        env:
          FOSS_REPO_URL: git@github.com:onyx-dot-app/onyx-foss.git
        run: |
          cd /tmp/foss_repo
          git remote add public "$FOSS_REPO_URL"
          git push --force public main
```
.gitignore (vendored, 5 lines changed)

```diff
@@ -31,6 +31,11 @@ settings.json
 /backend/tests/regression/answer_quality/search_test_config.yaml
 *.egg-info

+# Claude
+AGENTS.md
+CLAUDE.md
+
+
 # Local .terraform directories
 **/.terraform/*
```
```diff
@@ -29,6 +29,7 @@ repos:
     rev: v0.11.4
     hooks:
       - id: ruff
+
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v3.1.0
     hooks:
@@ -36,14 +37,32 @@ repos:
         types_or: [html, css, javascript, ts, tsx]
         language_version: system

+  - repo: https://github.com/sirwart/ripsecrets
+    rev: v0.1.11
+    hooks:
+      - id: ripsecrets
+        args:
+          - --additional-pattern
+          - ^sk-[A-Za-z0-9_\-]{20,}$
+
   - repo: local
     hooks:
+      - id: terraform-fmt
+        name: terraform fmt
+        entry: terraform fmt -recursive
+        language: system
+        pass_filenames: false
+        files: \.tf$
       - id: check-lazy-imports
         name: Check lazy imports are not directly imported
         entry: python3 backend/scripts/check_lazy_imports.py
         language: system
-        files: ^backend/.*\.py$
+        files: ^backend/(?!\.venv/).*\.py$
         pass_filenames: false
   # Note: pass_filenames is false because tsc must check the entire
   # project, but the files filter ensures this only runs when relevant
   # files change. Using --incremental for faster subsequent checks.

 # We would like to have a mypy pre-commit hook, but due to the fact that
 # pre-commit runs in it's own isolated environment, we would need to install
```
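The extra ripsecrets pattern can be sanity-checked locally; a sketch with a made-up key (the regex is the one added above, and `grep -E` is close enough for this pattern even though pre-commit/ripsecrets use their own regex engines):

```bash
# Should match: "sk-" followed by 20+ characters from [A-Za-z0-9_-] (fake key for illustration)
echo 'sk-aaaaaaaaaaaaaaaaaaaa' | grep -Eq '^sk-[A-Za-z0-9_-]{20,}$' && echo "flagged as secret-like"
```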
.vscode/env_template.txt (vendored, 13 lines changed)

```diff
@@ -1,6 +1,6 @@
 # Copy this file to .env in the .vscode folder
 # Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
-# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
+# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes

 # For local dev, often user Authentication is not needed
 AUTH_TYPE=disabled
@@ -37,8 +37,8 @@ OPENAI_API_KEY=<REPLACE THIS>
 GEN_AI_MODEL_VERSION=gpt-4o
 FAST_GEN_AI_MODEL_VERSION=gpt-4o

-# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
-# Only needed if using DanswerBot
+# For Onyx Slack Bot, overrides the UI values so no need to set this up via UI every time
+# Only needed if using OnyxBot
 #ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
 #ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
@@ -75,4 +75,9 @@ SHOW_EXTRA_CONNECTORS=True
 LANGSMITH_TRACING="true"
 LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
 LANGSMITH_API_KEY=<REPLACE_THIS>
 LANGSMITH_PROJECT=<REPLACE_THIS>
+
+# Local Confluence OAuth testing
+# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
+# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
+# NEXT_PUBLIC_TEST_ENV=True
```
````diff
@@ -194,13 +194,15 @@ alembic -n schema_private upgrade head

 ### Creating Migrations
 ```bash
-# Auto-generate migration
-alembic revision --autogenerate -m "description"
+# Create migration
+alembic revision -m "description"

 # Multi-tenant migration
-alembic -n schema_private revision --autogenerate -m "description"
+alembic -n schema_private revision -m "description"
 ```

+Write the migration manually and place it in the file that alembic creates when running the above command.
+
 ## Testing Strategy

 There are 4 main types of tests within Onyx:
````

````diff
@@ -197,15 +197,19 @@ alembic -n schema_private upgrade head

 ### Creating Migrations
 ```bash
-# Auto-generate migration
-alembic revision --autogenerate -m "description"
+# Create migration
+alembic revision -m "description"

 # Multi-tenant migration
-alembic -n schema_private revision --autogenerate -m "description"
+alembic -n schema_private revision -m "description"
 ```

+Write the migration manually and place it in the file that alembic creates when running the above command.
+
 ## Testing Strategy

+First, you must activate the virtual environment with `source .venv/bin/activate`.
+
 There are 4 main types of tests within Onyx:

 ### Unit Tests
@@ -216,7 +220,7 @@ write these for complex, isolated modules e.g. `citation_processing.py`.
 To run them:

 ```bash
-python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
+pytest -xv backend/tests/unit
 ```

 ### External Dependency Unit Tests
````
````diff
@@ -94,6 +94,12 @@ If using PowerShell, the command slightly differs:

 Install the required python dependencies:

+```bash
+pip install -r backend/requirements/combined.txt
+```
+
+or
+
 ```bash
 pip install -r backend/requirements/default.txt
 pip install -r backend/requirements/dev.txt
@@ -122,7 +128,7 @@ Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)]
 to manage your Node installations. Once installed, you can run

 ```bash
-nvm install 22 && nvm use 22`
+nvm install 22 && nvm use 22
+node -v # verify your active version
 ```
````
```diff
@@ -15,8 +15,8 @@ ENV ONYX_VERSION=${ONYX_VERSION} \
     DO_NOT_TRACK="true" \
     PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"

+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
 RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
 # Install system dependencies
 # cmake needed for psycopg (postgres)
 # libpq-dev needed for psycopg (postgres)
@@ -48,22 +48,19 @@ RUN apt-get update && \
 # Remove py which is pulled in by retry, py is not needed and is a CVE
 COPY ./requirements/default.txt /tmp/requirements.txt
 COPY ./requirements/ee.txt /tmp/ee-requirements.txt
-RUN pip install --no-cache-dir --upgrade \
-    --retries 5 \
-    --timeout 30 \
+RUN uv pip install --system --no-cache-dir --upgrade \
     -r /tmp/requirements.txt \
     -r /tmp/ee-requirements.txt && \
     pip uninstall -y py && \
     playwright install chromium && \
     playwright install-deps chromium && \
-    ln -s /usr/local/bin/supervisord /usr/bin/supervisord
-
-# Cleanup for CVEs and size reduction
-# https://github.com/tornadoweb/tornado/issues/3107
-# xserver-common and xvfb included by playwright installation but not needed after
-# perl-base is part of the base Python Debian image but not needed for Onyx functionality
-# perl-base could only be removed with --allow-remove-essential
-RUN apt-get update && \
+    ln -s /usr/local/bin/supervisord /usr/bin/supervisord && \
+    # Cleanup for CVEs and size reduction
+    # https://github.com/tornadoweb/tornado/issues/3107
+    # xserver-common and xvfb included by playwright installation but not needed after
+    # perl-base is part of the base Python Debian image but not needed for Onyx functionality
+    # perl-base could only be removed with --allow-remove-essential
+    apt-get update && \
     apt-get remove -y --allow-remove-essential \
         perl-base \
         xserver-common \
@@ -73,15 +70,16 @@ RUN apt-get update && \
         libxmlsec1-dev \
         pkg-config \
         gcc && \
-    apt-get install -y libxmlsec1-openssl && \
+    # Install here to avoid some packages being cleaned up above
+    apt-get install -y \
+        libxmlsec1-openssl \
+        # Install postgresql-client for easy manual tests
+        postgresql-client && \
     apt-get autoremove -y && \
     rm -rf /var/lib/apt/lists/* && \
+    rm -rf ~/.cache/uv /tmp/*.txt && \
     rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

-# Install postgresql-client for easy manual tests
-# Install it here to avoid it being cleaned up above
-RUN apt-get update && apt-get install -y postgresql-client
-
 # Pre-downloading models for setups with limited egress
 RUN python -c "from tokenizers import Tokenizer; \
     Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
@@ -95,38 +93,37 @@ nltk.download('punkt_tab', quiet=True);"
 # Set up application files
 WORKDIR /app

-# Enterprise Version Files
-COPY ./ee /app/ee
-COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
-
-# Set up application files
-COPY ./onyx /app/onyx
-COPY ./shared_configs /app/shared_configs
-COPY ./alembic /app/alembic
-COPY ./alembic_tenants /app/alembic_tenants
-COPY ./alembic.ini /app/alembic.ini
-COPY supervisord.conf /usr/etc/supervisord.conf
-COPY ./static /app/static
-
-# Escape hatch scripts
-COPY ./scripts/debugging /app/scripts/debugging
-COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
-COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
-RUN chmod +x /app/scripts/supervisord_entrypoint.sh
-
-# Put logo in assets
-COPY ./assets /app/assets
-
-ENV PYTHONPATH=/app
-
 # Create non-root user for security best practices
 RUN groupadd -g 1001 onyx && \
     useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
     chown -R onyx:onyx /app && \
     mkdir -p /var/log/onyx && \
     chmod 755 /var/log/onyx && \
     chown onyx:onyx /var/log/onyx

+# Enterprise Version Files
+COPY --chown=onyx:onyx ./ee /app/ee
 COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf

+# Set up application files
+COPY --chown=onyx:onyx ./onyx /app/onyx
+COPY --chown=onyx:onyx ./shared_configs /app/shared_configs
+COPY --chown=onyx:onyx ./alembic /app/alembic
+COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
+COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf
+COPY --chown=onyx:onyx ./static /app/static

+# Escape hatch scripts
+COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
+COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
+COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
 RUN chmod +x /app/scripts/supervisord_entrypoint.sh

+# Put logo in assets
+COPY --chown=onyx:onyx ./assets /app/assets

 ENV PYTHONPATH=/app

 # Default command which does nothing
 # This container is used by api server and background which specify their own CMD
 CMD ["tail", "-f", "/dev/null"]
```
```diff
@@ -12,7 +12,7 @@ ENV ONYX_VERSION=${ONYX_VERSION} \
     DANSWER_RUNNING_IN_DOCKER="true" \
     HF_HOME=/app/.cache/huggingface

 RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

 # Create non-root user for security best practices
 RUN mkdir -p /app && \
@@ -34,19 +34,17 @@ RUN set -eux; \
         pkg-config \
         curl \
         ca-certificates \
     && rm -rf /var/lib/apt/lists/* \
     # Install latest stable Rust (supports Cargo.lock v4)
     && curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
-    && rustc --version && cargo --version
+    && rustc --version && cargo --version \
+    && apt-get remove -y --allow-remove-essential perl-base \
+    && apt-get autoremove -y \
+    && rm -rf /var/lib/apt/lists/*

 COPY ./requirements/model_server.txt /tmp/requirements.txt
-RUN pip install --no-cache-dir --upgrade \
-    --retries 5 \
-    --timeout 30 \
-    -r /tmp/requirements.txt
-
-RUN apt-get remove -y --allow-remove-essential perl-base && \
-    apt-get autoremove -y
+RUN uv pip install --system --no-cache-dir --upgrade \
+    -r /tmp/requirements.txt && \
+    rm -rf ~/.cache/uv /tmp/*.txt

 # Pre-downloading models for setups with limited egress
 # Download tokenizers, distilbert for the Onyx model
@@ -61,12 +59,11 @@ snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
 snapshot_download('nomic-ai/nomic-embed-text-v1'); \
 snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
 from sentence_transformers import SentenceTransformer; \
-SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
-
-# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
-# running Onyx, move the current contents of the cache folder to a temporary location to ensure
-# it's preserved in order to combine with the user's cache contents
-RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
+SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);" && \
+    # In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
+    # running Onyx, move the current contents of the cache folder to a temporary location to ensure
+    # it's preserved in order to combine with the user's cache contents
+    mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
     chown -R onyx:onyx /app

 WORKDIR /app
```
@@ -0,0 +1,33 @@
```python
"""add theme_preference to user

Revision ID: 09995b8811eb
Revises: 3d1cca026fe8
Create Date: 2025-10-24 08:58:50.246949

"""

from alembic import op
import sqlalchemy as sa
from onyx.db.enums import ThemePreference


# revision identifiers, used by Alembic.
revision = "09995b8811eb"
down_revision = "3d1cca026fe8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "theme_preference",
            sa.Enum(ThemePreference, native_enum=False),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("user", "theme_preference")
```
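Applying or reverting this revision follows the usual Alembic flow described later in this diff (run from `backend/` with the virtual environment active):

```bash
alembic upgrade 09995b8811eb    # apply: adds the nullable user.theme_preference column
alembic downgrade 3d1cca026fe8  # revert to the prior revision, dropping the column
```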
```diff
@@ -12,6 +12,7 @@ from alembic import op
 import sqlalchemy as sa
 from sqlalchemy import text
 import logging
+import fastapi_users_db_sqlalchemy

 logger = logging.getLogger("alembic.runtime.migration")

@@ -58,6 +59,9 @@ def upgrade() -> None:
     logger.info("Dropping chat_session.folder_id...")

     # Drop foreign key constraint first
+    op.execute(
+        "ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_chat_folder_fk"
+    )
     op.execute(
         "ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk"
     )
```
```diff
@@ -172,20 +176,6 @@ def downgrade() -> None:
         "user_file", sa.Column("folder_id", sa.Integer(), nullable=True)
     )

-    # Recreate chat_folder table
-    if "chat_folder" not in inspector.get_table_names():
-        op.create_table(
-            "chat_folder",
-            sa.Column("id", sa.Integer(), nullable=False),
-            sa.Column("user_id", sa.UUID(), nullable=False),
-            sa.Column("name", sa.String(), nullable=False),
-            sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
-            sa.PrimaryKeyConstraint("id"),
-            sa.ForeignKeyConstraint(
-                ["user_id"], ["user.id"], name="chat_folder_user_fk"
-            ),
-        )
-
     # Recreate persona__user_folder table
     if "persona__user_folder" not in inspector.get_table_names():
         op.create_table(
@@ -197,6 +187,26 @@ def downgrade() -> None:
             sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]),
         )

+    # Recreate chat_folder table and related structures
+    if "chat_folder" not in inspector.get_table_names():
+        op.create_table(
+            "chat_folder",
+            sa.Column("id", sa.Integer(), nullable=False),
+            sa.Column(
+                "user_id",
+                fastapi_users_db_sqlalchemy.generics.GUID(),
+                nullable=True,
+            ),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("display_priority", sa.Integer(), nullable=False),
+            sa.ForeignKeyConstraint(
+                ["user_id"],
+                ["user.id"],
+                name="chat_folder_user_id_fkey",
+            ),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
     # Add folder_id back to chat_session
     if "chat_session" in inspector.get_table_names():
         columns = [col["name"] for col in inspector.get_columns("chat_session")]
@@ -208,7 +218,7 @@ def downgrade() -> None:
         # Add foreign key if chat_folder exists
         if "chat_folder" in inspector.get_table_names():
             op.create_foreign_key(
-                "chat_session_folder_fk",
+                "chat_session_chat_folder_fk",
                 "chat_session",
                 "chat_folder",
                 ["folder_id"],
```
```diff
@@ -292,7 +292,7 @@ def downgrade() -> None:
     logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
     logger.error("Data restoration requires backup files or database backup.")

-    raise NotImplementedError(
-        "Downgrade of legacy data cleanup is not supported. "
-        "Deleted data must be restored from backups."
-    )
+    # raise NotImplementedError(
+    #     "Downgrade of legacy data cleanup is not supported. "
+    #     "Deleted data must be restored from backups."
+    # )
```
@@ -0,0 +1,121 @@
```python
"""add_oauth_config_and_user_tokens

Revision ID: 3d1cca026fe8
Revises: c8a93a2af083
Create Date: 2025-10-21 13:27:34.274721

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3d1cca026fe8"
down_revision = "c8a93a2af083"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create oauth_config table
    op.create_table(
        "oauth_config",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("authorization_url", sa.Text(), nullable=False),
        sa.Column("token_url", sa.Text(), nullable=False),
        sa.Column("client_id", sa.LargeBinary(), nullable=False),
        sa.Column("client_secret", sa.LargeBinary(), nullable=False),
        sa.Column("scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column(
            "additional_params",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )

    # Create oauth_user_token table
    op.create_table(
        "oauth_user_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("oauth_config_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column("token_data", sa.LargeBinary(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["oauth_config_id"], ["oauth_config.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("oauth_config_id", "user_id", name="uq_oauth_user_token"),
    )

    # Create index on user_id for efficient user-based token lookups
    # Note: unique constraint on (oauth_config_id, user_id) already creates
    # an index for config-based lookups
    op.create_index(
        "ix_oauth_user_token_user_id",
        "oauth_user_token",
        ["user_id"],
    )

    # Add oauth_config_id column to tool table
    op.add_column("tool", sa.Column("oauth_config_id", sa.Integer(), nullable=True))

    # Create foreign key from tool to oauth_config
    op.create_foreign_key(
        "tool_oauth_config_fk",
        "tool",
        "oauth_config",
        ["oauth_config_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Drop foreign key from tool to oauth_config
    op.drop_constraint("tool_oauth_config_fk", "tool", type_="foreignkey")

    # Drop oauth_config_id column from tool table
    op.drop_column("tool", "oauth_config_id")

    # Drop index on user_id
    op.drop_index("ix_oauth_user_token_user_id", table_name="oauth_user_token")

    # Drop oauth_user_token table (will cascade delete tokens)
    op.drop_table("oauth_user_token")

    # Drop oauth_config table
    op.drop_table("oauth_config")
```
```diff
@@ -45,8 +45,23 @@ def upgrade() -> None:


 def downgrade() -> None:
-    op.drop_constraint(
-        "chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
-    )
-    op.drop_column("chat_session", "folder_id")
-    op.drop_table("chat_folder")
+    bind = op.get_bind()
+    inspector = sa.inspect(bind)
+
+    if "chat_session" in inspector.get_table_names():
+        chat_session_fks = {
+            fk.get("name") for fk in inspector.get_foreign_keys("chat_session")
+        }
+        if "chat_session_chat_folder_fk" in chat_session_fks:
+            op.drop_constraint(
+                "chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
+            )
+
+        chat_session_columns = {
+            col["name"] for col in inspector.get_columns("chat_session")
+        }
+        if "folder_id" in chat_session_columns:
+            op.drop_column("chat_session", "folder_id")
+
+    if "chat_folder" in inspector.get_table_names():
+        op.drop_table("chat_folder")
```
```diff
@@ -180,14 +180,162 @@ def downgrade() -> None:
     )
     logger.error("Only proceed if absolutely necessary and have backups.")

-    # The downgrade would need to:
-    # 1. Add back integer columns
-    # 2. Generate new sequential IDs
-    # 3. Update all foreign key references
-    # 4. Swap primary keys back
-    # This is complex and risky, so we raise an error instead
-    raise NotImplementedError(
-        "Downgrade of UUID primary key swap is not supported due to data loss risk. "
-        "Manual intervention with data backup/restore is required."
-    )
+    bind = op.get_bind()
+    inspector = sa.inspect(bind)
+
+    # Capture existing primary key definitions so we can restore them after swaps
+    persona_pk = inspector.get_pk_constraint("persona__user_file") or {}
+    persona_pk_name = persona_pk.get("name")
+    persona_pk_cols = persona_pk.get("constrained_columns") or []
+
+    project_pk = inspector.get_pk_constraint("project__user_file") or {}
+    project_pk_name = project_pk.get("name")
+    project_pk_cols = project_pk.get("constrained_columns") or []
+
+    # Drop foreign keys that reference the UUID primary key
+    op.drop_constraint(
+        "persona__user_file_user_file_id_fkey",
+        "persona__user_file",
+        type_="foreignkey",
+    )
+    op.drop_constraint(
+        "fk_project__user_file_user_file_id",
+        "project__user_file",
+        type_="foreignkey",
+    )
+
+    # Drop primary keys that rely on the UUID column so we can replace it
+    if persona_pk_name:
+        op.drop_constraint(persona_pk_name, "persona__user_file", type_="primary")
+    if project_pk_name:
+        op.drop_constraint(project_pk_name, "project__user_file", type_="primary")
+
+    # Rebuild integer IDs on user_file using a sequence-backed column
+    op.execute("CREATE SEQUENCE IF NOT EXISTS user_file_id_seq")
+    op.add_column(
+        "user_file",
+        sa.Column(
+            "id_int",
+            sa.Integer(),
+            server_default=sa.text("nextval('user_file_id_seq')"),
+            nullable=False,
+        ),
+    )
+    op.execute("ALTER SEQUENCE user_file_id_seq OWNED BY user_file.id_int")
+
+    # Prepare integer foreign key columns on referencing tables
+    op.add_column(
+        "persona__user_file",
+        sa.Column("user_file_id_int", sa.Integer(), nullable=True),
+    )
+    op.add_column(
+        "project__user_file",
+        sa.Column("user_file_id_int", sa.Integer(), nullable=True),
+    )
+
+    # Populate the new integer foreign key columns by mapping from the UUID IDs
+    op.execute(
+        """
+        UPDATE persona__user_file AS p
+        SET user_file_id_int = uf.id_int
+        FROM user_file AS uf
+        WHERE p.user_file_id = uf.id
+        """
+    )
+    op.execute(
+        """
+        UPDATE project__user_file AS p
+        SET user_file_id_int = uf.id_int
+        FROM user_file AS uf
+        WHERE p.user_file_id = uf.id
+        """
+    )
+
+    op.alter_column(
+        "persona__user_file",
+        "user_file_id_int",
+        existing_type=sa.Integer(),
+        nullable=False,
+    )
+    op.alter_column(
+        "project__user_file",
+        "user_file_id_int",
+        existing_type=sa.Integer(),
+        nullable=False,
+    )
+
+    # Remove the UUID foreign key columns and rename the integer replacements
+    op.drop_column("persona__user_file", "user_file_id")
+    op.alter_column(
+        "persona__user_file",
+        "user_file_id_int",
+        new_column_name="user_file_id",
+        existing_type=sa.Integer(),
+        nullable=False,
+    )
+
+    op.drop_column("project__user_file", "user_file_id")
+    op.alter_column(
+        "project__user_file",
+        "user_file_id_int",
+        new_column_name="user_file_id",
+        existing_type=sa.Integer(),
+        nullable=False,
+    )
+
+    # Swap the user_file primary key back to the integer column
+    op.drop_constraint("user_file_pkey", "user_file", type_="primary")
+    op.drop_column("user_file", "id")
+    op.alter_column(
+        "user_file",
+        "id_int",
+        new_column_name="id",
+        existing_type=sa.Integer(),
+    )
+    op.alter_column(
+        "user_file",
+        "id",
+        existing_type=sa.Integer(),
+        nullable=False,
+        server_default=sa.text("nextval('user_file_id_seq')"),
+    )
+    op.execute("ALTER SEQUENCE user_file_id_seq OWNED BY user_file.id")
+    op.execute(
+        """
+        SELECT setval(
+            'user_file_id_seq',
+            GREATEST(COALESCE(MAX(id), 1), 1),
+            MAX(id) IS NOT NULL
+        )
+        FROM user_file
+        """
+    )
+    op.create_primary_key("user_file_pkey", "user_file", ["id"])
+
+    # Restore primary keys on referencing tables
+    if persona_pk_cols:
+        op.create_primary_key(
+            "persona__user_file_pkey", "persona__user_file", persona_pk_cols
+        )
+    if project_pk_cols:
+        op.create_primary_key(
+            "project__user_file_pkey",
+            "project__user_file",
+            project_pk_cols,
+        )
+
+    # Recreate foreign keys pointing at the integer primary key
+    op.create_foreign_key(
+        "persona__user_file_user_file_id_fkey",
+        "persona__user_file",
+        "user_file",
+        ["user_file_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "fk_project__user_file_user_file_id",
+        "project__user_file",
+        "user_file",
+        ["user_file_id"],
+        ["id"],
+    )
```
```diff
@@ -181,12 +181,21 @@ def upgrade() -> None:
             sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
             sa.PrimaryKeyConstraint("project_id", "user_file_id"),
         )
-        logger.info("Created project__user_file table")

+        # Only create the index if it doesn't exist
+        existing_indexes = [
+            ix["name"] for ix in inspector.get_indexes("project__user_file")
+        ]
+        if "idx_project__user_file_user_file_id" not in existing_indexes:
+            op.create_index(
+                "idx_project__user_file_user_file_id",
+                "project__user_file",
+                ["user_file_id"],
+            )
+        logger.info("Created project__user_file table")
+        logger.info(
+            "Created index idx_project__user_file_user_file_id on project__user_file"
+        )

     logger.info("Migration 1 (schema additions) completed successfully")

@@ -201,7 +210,7 @@ def downgrade() -> None:

     # Drop project__user_file table
     if "project__user_file" in inspector.get_table_names():
-        op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
+        # op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
         op.drop_table("project__user_file")
         logger.info("Dropped project__user_file table")
```
backend/docker-bake.hcl (new file, 24 lines)

@@ -0,0 +1,24 @@
```hcl
variable "REGISTRY" {
  default = "onyxdotapp"
}

variable "TAG" {
  default = "latest"
}

target "backend" {
  context    = "."
  dockerfile = "Dockerfile"
}

target "integration" {
  context    = "."
  dockerfile = "tests/integration/Dockerfile"

  // Provide the base image via build context from the backend target
  contexts = {
    base = "target:backend"
  }

  tags = ["${REGISTRY}/integration-test-onyx-integration:${TAG}"]
}
```
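This bake file is what the updated integration-test workflows invoke; the local equivalent, using the same variables the workflow sets (the registry and tag values here are illustrative):

```bash
cd backend
# Build the integration target; add --push to publish to $REGISTRY
REGISTRY=onyxdotapp TAG=test-local docker buildx bake integration
```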
```diff
@@ -5,7 +5,6 @@ from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from redis.lock import Lock as RedisLock

-from ee.onyx.server.tenants.product_gating import get_gated_tenants
 from onyx.background.celery.apps.app_base import task_logger
 from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -52,10 +51,18 @@ def cloud_beat_task_generator(

     try:
         tenant_ids = get_all_tenant_ids()
-        gated_tenants = get_gated_tenants()
+
+        # NOTE: for now, we are running tasks for gated tenants, since we want to allow
+        # connector deletion to run successfully. The new plan is to continously prune
+        # the gated tenants set, so we won't have a build up of old, unused gated tenants.
+        # Keeping this around in case we want to revert to the previous behavior.
+        # gated_tenants = get_gated_tenants()

         for tenant_id in tenant_ids:
-            if tenant_id in gated_tenants:
-                continue
+            # Same comment here as the above NOTE
+            # if tenant_id in gated_tenants:
+            #     continue

             current_time = time.monotonic()
             if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
```
```diff
@@ -18,7 +18,7 @@
   <!-- <document type="danswer_chunk" mode="index" /> -->
   {{ document_elements }}
 </documents>
-<nodes count="60">
+<nodes count="50">
   <resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
     disk="475.0Gb" />
 </nodes>
```
```diff
@@ -139,19 +139,13 @@ def get_all_space_permissions(
 ) -> dict[str, ExternalAccess]:
     logger.debug("Getting space permissions")
     # Gets all the spaces in the Confluence instance
-    all_space_keys = []
-    start = 0
-    while True:
-        spaces_batch = confluence_client.get_all_spaces(
-            start=start, limit=REQUEST_PAGINATION_LIMIT
-        )
-        for space in spaces_batch.get("results", []):
-            all_space_keys.append(space.get("key"))
-
-        if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
-            break
-
-        start += len(spaces_batch.get("results", []))
+    all_space_keys = [
+        key
+        for space in confluence_client.retrieve_confluence_spaces(
+            limit=REQUEST_PAGINATION_LIMIT,
+        )
+        if (key := space.get("key"))
+    ]

     # Gets the permissions for each space
     logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
```
```diff
@@ -76,6 +76,7 @@ class ConfluenceCloudOAuth:
         "read:confluence-content.permission%20"
         "read:confluence-user%20"
         "read:confluence-groups%20"
+        "read:space:confluence%20"
         "readonly:content.attachment:confluence%20"
         "search:confluence%20"
         # granular scope
```
134
backend/onyx/agents/agent_sdk/message_format.py
Normal file
134
backend/onyx/agents/agent_sdk/message_format.py
Normal file
@@ -0,0 +1,134 @@
from collections.abc import Sequence

from langchain.schema.messages import BaseMessage

from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
from onyx.agents.agent_sdk.message_types import ImageContent
from onyx.agents.agent_sdk.message_types import InputTextContent
from onyx.agents.agent_sdk.message_types import SystemMessage
from onyx.agents.agent_sdk.message_types import UserMessage


# TODO: Currently, we only support native API input for images. For other
# files, we process the content and share it as text in the message. In
# the future, we might support native file uploads for other types of files.
def base_messages_to_agent_sdk_msgs(
    msgs: Sequence[BaseMessage],
) -> list[AgentSDKMessage]:
    return [_base_message_to_agent_sdk_msg(msg) for msg in msgs]


def _base_message_to_agent_sdk_msg(msg: BaseMessage) -> AgentSDKMessage:
    message_type_to_agent_sdk_role = {
        "human": "user",
        "system": "system",
        "ai": "assistant",
    }
    role = message_type_to_agent_sdk_role[msg.type]

    # Convert content to Agent SDK format
    content = msg.content

    if isinstance(content, str):
        # For system/user/assistant messages, use InputTextContent
        if role in ("system", "user"):
            input_text_content: list[InputTextContent | ImageContent] = [
                InputTextContent(type="input_text", text=content)
            ]
            if role == "system":
                # SystemMessage only accepts InputTextContent
                system_msg: SystemMessage = {
                    "role": "system",
                    "content": [InputTextContent(type="input_text", text=content)],
                }
                return system_msg
            else:  # user
                user_msg: UserMessage = {
                    "role": "user",
                    "content": input_text_content,
                }
                return user_msg
        else:  # assistant
            assistant_msg: AssistantMessageWithContent = {
                "role": "assistant",
                "content": [InputTextContent(type="input_text", text=content)],
            }
            return assistant_msg
    elif isinstance(content, list):
        # For lists, we need to process based on the role
        if role == "assistant":
            # Assistant messages use InputTextContent | OutputTextContent
            from onyx.agents.agent_sdk.message_types import OutputTextContent

            assistant_content: list[InputTextContent | OutputTextContent] = []
            for item in content:
                if isinstance(item, str):
                    assistant_content.append(
                        InputTextContent(type="input_text", text=item)
                    )
                elif isinstance(item, dict) and item.get("type") == "text":
                    assistant_content.append(
                        InputTextContent(type="input_text", text=item.get("text", ""))
                    )
                else:
                    raise ValueError(
                        f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
                    )
            assistant_msg_list: AssistantMessageWithContent = {
                "role": "assistant",
                "content": assistant_content,
            }
            return assistant_msg_list
        else:  # system or user - use InputTextContent
            input_content: list[InputTextContent | ImageContent] = []
            for item in content:
                if isinstance(item, str):
                    input_content.append(InputTextContent(type="input_text", text=item))
                elif isinstance(item, dict):
                    item_type = item.get("type")
                    if item_type == "text":
                        input_content.append(
                            InputTextContent(
                                type="input_text", text=item.get("text", "")
                            )
                        )
                    elif item_type == "image_url":
                        # Convert image_url to input_image format
                        image_url = item.get("image_url", {})
                        if isinstance(image_url, dict):
                            url = image_url.get("url", "")
                        else:
                            url = image_url
                        input_content.append(
                            ImageContent(
                                type="input_image", image_url=url, detail="auto"
                            )
                        )
                    else:
                        raise ValueError(f"Unexpected item type: {item_type}")
                else:
                    raise ValueError(
                        f"Unexpected item type: {type(item)}. Item: {item}"
                    )

            if role == "system":
                # SystemMessage only accepts InputTextContent (no images)
                text_only_content = [
                    c for c in input_content if c["type"] == "input_text"
                ]
                system_msg_list: SystemMessage = {
                    "role": "system",
                    "content": text_only_content,  # type: ignore[typeddict-item]
                }
                return system_msg_list
            else:  # user
                user_msg_list: UserMessage = {
                    "role": "user",
                    "content": input_content,
                }
                return user_msg_list
    else:
        raise ValueError(
            f"Unexpected content type: {type(content)}. Content: {content}"
        )
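A usage sketch of the converter above (illustrative only, not part of the diff; the message contents are hypothetical test data). String content on a LangChain message becomes a single input_text entry, and the LangChain "human"/"ai" types map to the "user"/"assistant" roles:

    from langchain.schema.messages import AIMessage
    from langchain.schema.messages import HumanMessage

    from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs

    # Hypothetical two-turn history: a user question and an assistant reply.
    history = [
        HumanMessage(content="What is Onyx?"),
        AIMessage(content="Onyx is an open-source AI assistant."),
    ]
    msgs = base_messages_to_agent_sdk_msgs(history)
    # msgs[0]["role"] == "user", msgs[1]["role"] == "assistant"; each string
    # content is wrapped as {"type": "input_text", "text": ...} per the code above.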
124
backend/onyx/agents/agent_sdk/message_types.py
Normal file
@@ -0,0 +1,124 @@
"""Strongly typed message structures for Agent SDK messages."""

from typing import Literal

from typing_extensions import TypedDict


class InputTextContent(TypedDict):
    type: Literal["input_text"]
    text: str


class OutputTextContent(TypedDict):
    type: Literal["output_text"]
    text: str


TextContent = InputTextContent | OutputTextContent


class ImageContent(TypedDict):
    type: Literal["input_image"]
    image_url: str
    detail: str


# Tool call structures
class ToolCallFunction(TypedDict):
    name: str
    arguments: str


class ToolCall(TypedDict):
    id: str
    type: Literal["function"]
    function: ToolCallFunction


# Message types
class SystemMessage(TypedDict):
    role: Literal["system"]
    content: list[InputTextContent]  # System messages use input text


class UserMessage(TypedDict):
    role: Literal["user"]
    content: list[
        InputTextContent | ImageContent
    ]  # User messages use input text or images


class AssistantMessageWithContent(TypedDict):
    role: Literal["assistant"]
    content: list[
        InputTextContent | OutputTextContent
    ]  # Assistant messages can receive output_text from agents SDK, but we convert to input_text


class AssistantMessageWithToolCalls(TypedDict):
    role: Literal["assistant"]
    tool_calls: list[ToolCall]


class AssistantMessageDuringAgentRun(TypedDict):
    role: Literal["assistant"]
    id: str
    content: (
        list[InputTextContent | OutputTextContent] | list[ToolCall]
    )  # Assistant runtime messages can receive output_text from agents SDK, but we convert to input_text
    status: Literal["completed", "failed", "in_progress"]
    type: Literal["message"]


class ToolMessage(TypedDict):
    role: Literal["tool"]
    content: str
    tool_call_id: str


class FunctionCallMessage(TypedDict):
    """Agent SDK function call message format."""

    type: Literal["function_call"]
    id: str
    call_id: str
    name: str
    arguments: str


class FunctionCallOutputMessage(TypedDict):
    """Agent SDK function call output message format."""

    type: Literal["function_call_output"]
    call_id: str
    output: str


class SummaryText(TypedDict):
    """Summary text item in reasoning messages."""

    text: str
    type: Literal["summary_text"]


class ReasoningMessage(TypedDict):
    """Agent SDK reasoning message format."""

    id: str
    type: Literal["reasoning"]
    summary: list[SummaryText]


# Union type for all Agent SDK messages
AgentSDKMessage = (
    SystemMessage
    | UserMessage
    | AssistantMessageWithContent
    | AssistantMessageWithToolCalls
    | AssistantMessageDuringAgentRun
    | ToolMessage
    | FunctionCallMessage
    | FunctionCallOutputMessage
    | ReasoningMessage
)
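Because these are TypedDicts, messages are constructed as plain dicts and checked statically rather than at runtime. An illustrative construction of a mixed text-and-image user turn (hypothetical values):

    from onyx.agents.agent_sdk.message_types import ImageContent
    from onyx.agents.agent_sdk.message_types import InputTextContent
    from onyx.agents.agent_sdk.message_types import UserMessage

    # Matches the UserMessage shape above: input text plus an input image.
    user_msg: UserMessage = {
        "role": "user",
        "content": [
            InputTextContent(type="input_text", text="Describe this chart."),
            ImageContent(
                type="input_image",
                image_url="https://example.com/chart.png",  # hypothetical URL
                detail="auto",
            ),
        ],
    }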
30
backend/onyx/agents/agent_sdk/monkey_patches.py
Normal file
@@ -0,0 +1,30 @@
from typing import Any

from agents.models.openai_responses import Converter as OpenAIResponsesConverter


# TODO: I am very sad that I have to monkey patch this :(
# Basically, OpenAI agents sdk doesn't convert the tool choice correctly
# when they have a built-in tool in their framework, like they do for web_search.
# Going to open up a thread with OpenAI agents team to see what they recommend
# or what we can fix.
# A discussion is warranted, but we likely want to just write our own LitellmModel for
# the OpenAI agents SDK since they probably don't really care about Litellm and will
# prioritize functionality for their own models.
def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> None:
    if (
        getattr(OpenAIResponsesConverter.convert_tool_choice, "__name__", "")
        == "_patched_convert_tool_choice"
    ):
        return

    orig_func = OpenAIResponsesConverter.convert_tool_choice.__func__  # type: ignore[attr-defined]

    def _patched_convert_tool_choice(cls: type, tool_choice: Any) -> Any:
        if tool_choice == "web_search":
            return {"type": "function", "name": "web_search"}
        return orig_func(cls, tool_choice)

    OpenAIResponsesConverter.convert_tool_choice = classmethod(  # type: ignore[method-assign, assignment]
        _patched_convert_tool_choice
    )
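The `__name__` guard makes the patch idempotent, so it is safe to apply from multiple entry points. A sketch of the assumed call pattern (the call site is not shown in this diff):

    from onyx.agents.agent_sdk.monkey_patches import (
        monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search,
    )

    # Apply once at startup; repeated calls are no-ops because the patched
    # function's __name__ is detected by the guard above.
    monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search()
    monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search()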
@@ -2,15 +2,17 @@ import asyncio
import queue
import threading
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Generic
from typing import Optional
from typing import TypeVar

from agents import Agent
from agents import RunResultStreaming
from agents import TContext
from agents.run import Runner

from onyx.chat.turn.models import ChatTurnContext
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.utils.threadpool_concurrency import run_in_background

T = TypeVar("T")
@@ -41,8 +43,8 @@ class SyncAgentStream(Generic[T]):
        self,
        *,
        agent: Agent,
        input: list[dict],
        context: ChatTurnContext,
        input: Sequence[AgentSDKMessage],
        context: TContext | None = None,
        max_turns: int = 100,
        queue_maxsize: int = 0,
    ) -> None:
@@ -54,7 +56,7 @@ class SyncAgentStream(Generic[T]):
        self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed: RunResultStreaming | None = None
        self.streamed: RunResultStreaming | None = None
        self._exc: Optional[BaseException] = None
        self._cancel_requested = threading.Event()
        self._started = threading.Event()
@@ -87,7 +89,7 @@ class SyncAgentStream(Generic[T]):
        """
        self._cancel_requested.set()
        loop = self._loop
        streamed = self._streamed
        streamed = self.streamed
        if loop is not None and streamed is not None and not self._done.is_set():
            loop.call_soon_threadsafe(streamed.cancel)
        return True
@@ -123,7 +125,7 @@ class SyncAgentStream(Generic[T]):
        async def worker() -> None:
            try:
                # Start the streamed run inside the loop thread
                self._streamed = Runner.run_streamed(
                self.streamed = Runner.run_streamed(
                    self._agent,
                    self._input,  # type: ignore[arg-type]
                    context=self._context,
@@ -132,15 +134,15 @@ class SyncAgentStream(Generic[T]):

                # If cancel was requested before we created _streamed, honor it now
                if self._cancel_requested.is_set():
                    await self._streamed.cancel()  # type: ignore[func-returns-value]
                    await self.streamed.cancel()  # type: ignore[func-returns-value]

                # Consume async events and forward into the thread-safe queue
                async for ev in self._streamed.stream_events():
                async for ev in self.streamed.stream_events():
                    # Early exit if a late cancel arrives
                    if self._cancel_requested.is_set():
                        # Try to cancel gracefully; don't break until cancel takes effect
                        try:
                            await self._streamed.cancel()  # type: ignore[func-returns-value]
                            await self.streamed.cancel()  # type: ignore[func-returns-value]
                        except Exception:
                            pass
                        break
@@ -174,4 +176,3 @@ class SyncAgentStream(Generic[T]):
        finally:
            loop.close()
            self._loop = None
            self._streamed = None
@@ -14,10 +14,10 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.configs.constants import DocumentSource
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

@@ -61,10 +61,12 @@ def search_objects(
    if agent_1_independent_sources_str is None:
        raise ValueError("Agent 1 Independent Research Sources not found")

    document_sources = [
        DocumentSource(x.strip().lower())
        for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
    ]
    document_sources = strings_to_document_sources(
        [
            x.strip().lower()
            for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
        ]
    )

    agent_1_output_objective = extract_section(
        agent_1_instructions, "Output Objective:"

@@ -98,14 +98,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


def _format_tool_name(tool_name: str) -> str:
    """Convert tool name to LLM-friendly format."""
    name = tool_name.replace(" ", "_")
    # take care of camel case like GetAPIKey -> GET_API_KEY for LLM readability
    name = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name)
    return name.upper()


def _get_available_tools(
    db_session: Session,
    graph_config: GraphConfig,
@@ -562,7 +554,7 @@ def clarifier(
    # if there is only one tool (Closer), we don't need to decide. It's an LLM answer
    llm_decision = DecisionResponse(decision="LLM", reasoning="")

    if llm_decision.decision == "LLM":
    if llm_decision.decision == "LLM" and research_type != ResearchType.DEEP:

        write_custom_event(
            current_step_nr,
@@ -702,55 +694,58 @@ def clarifier(
            should_stream_answer=True,
            writer=writer,
            ind=0,
            final_search_results=context_llm_docs,
            displayed_search_results=context_llm_docs,
            search_results=context_llm_docs,
            generate_final_answer=True,
            chat_message_id=str(graph_config.persistence.chat_session_id),
        )

        full_response = stream_and_process()
        if len(full_response.ai_message_chunk.tool_calls) == 0:
        # Deep research always continues to clarification or search
        if research_type != ResearchType.DEEP:
            full_response = stream_and_process()
            if len(full_response.ai_message_chunk.tool_calls) == 0:

                if isinstance(full_response.full_answer, str):
                    full_answer = (
                        normalize_square_bracket_citations_to_double_with_links(
                            full_response.full_answer
            if isinstance(full_response.full_answer, str):
                full_answer = (
                    normalize_square_bracket_citations_to_double_with_links(
                        full_response.full_answer
                    )
                )
            else:
                full_answer = None

            # Persist final documents and derive citations when using in-context docs
            final_documents_db, citations_map = (
                _persist_final_docs_and_citations(
                    db_session=db_session,
                    context_llm_docs=context_llm_docs,
                    full_answer=full_answer,
                )
            )
        else:
            full_answer = None

        # Persist final documents and derive citations when using in-context docs
        final_documents_db, citations_map = _persist_final_docs_and_citations(
            db_session=db_session,
            context_llm_docs=context_llm_docs,
            full_answer=full_answer,
        )
        update_db_session_with_messages(
            db_session=db_session,
            chat_message_id=message_id,
            chat_session_id=graph_config.persistence.chat_session_id,
            is_agentic=graph_config.behavior.use_agentic_search,
            message=full_answer,
            token_count=len(llm_tokenizer.encode(full_answer or "")),
            citations=citations_map,
            final_documents=final_documents_db or None,
            update_parent_message=True,
            research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        )

        update_db_session_with_messages(
            db_session=db_session,
            chat_message_id=message_id,
            chat_session_id=graph_config.persistence.chat_session_id,
            is_agentic=graph_config.behavior.use_agentic_search,
            message=full_answer,
            token_count=len(llm_tokenizer.encode(full_answer or "")),
            citations=citations_map,
            final_documents=final_documents_db or None,
            update_parent_message=True,
            research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        )
        db_session.commit()

        db_session.commit()

        return OrchestrationSetup(
            original_question=original_question,
            chat_history_string="",
            tools_used=[DRPath.END.value],
            query_list=[],
            available_tools=available_tools,
            assistant_system_prompt=assistant_system_prompt,
            assistant_task_prompt=assistant_task_prompt,
        )
        return OrchestrationSetup(
            original_question=original_question,
            chat_history_string="",
            tools_used=[DRPath.END.value],
            query_list=[],
            available_tools=available_tools,
            assistant_system_prompt=assistant_system_prompt,
            assistant_task_prompt=assistant_task_prompt,
        )

    # Continue, as external knowledge is required.

@@ -41,18 +41,16 @@ def process_llm_stream(
    should_stream_answer: bool,
    writer: StreamWriter,
    ind: int,
    final_search_results: list[LlmDoc] | None = None,
    displayed_search_results: list[LlmDoc] | None = None,
    search_results: list[LlmDoc] | None = None,
    generate_final_answer: bool = False,
    chat_message_id: str | None = None,
) -> BasicSearchProcessedStreamResults:
    tool_call_chunk = AIMessageChunk(content="")

    if final_search_results and displayed_search_results:
    if search_results:
        answer_handler: AnswerResponseHandler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
            context_docs=search_results,
            doc_id_to_rank_map=map_document_id_order(search_results),
        )
    else:
        answer_handler = PassThroughAnswerResponseHandler()
@@ -78,7 +76,7 @@
        ):
            tool_call_chunk += message  # type: ignore
        elif should_stream_answer:
            for response_part in answer_handler.handle_response_part(message, []):
            for response_part in answer_handler.handle_response_part(message):

                # only stream out answer parts
                if (
@@ -94,7 +92,7 @@
                if not start_final_answer_streaming_set:
                    # Convert LlmDocs to SavedSearchDocs
                    saved_search_docs = saved_search_docs_from_llm_docs(
                        final_search_results
                        search_results
                    )
                    write_custom_event(
                        ind,

@@ -30,6 +30,7 @@ from onyx.db.connector import DocumentSource
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
@@ -128,10 +129,11 @@ def basic_search(
    if re.match(date_pattern, implied_start_date):
        implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")

    specified_source_types: list[DocumentSource] | None = [
        DocumentSource(source_type)
        for source_type in search_processing.specified_source_types
    ]
    specified_source_types: list[DocumentSource] | None = (
        strings_to_document_sources(search_processing.specified_source_types)
        if search_processing.specified_source_types
        else None
    )

    if specified_source_types is not None and len(specified_source_types) == 0:
        specified_source_types = None

@@ -117,10 +117,8 @@ def image_generation(

    # save images to file store
    file_ids = save_files(
        urls=[img.url for img in image_generation_responses if img.url],
        base64_files=[
            img.image_data for img in image_generation_responses if img.image_data
        ],
        urls=[],
        base64_files=[img.image_data for img in image_generation_responses],
    )

    final_generated_images = [

@@ -1,3 +1,5 @@
from collections.abc import Sequence

from exa_py import Exa
from exa_py.api import HighlightsContentsOptions

@@ -47,9 +49,9 @@ class ExaClient(WebSearchProvider):
    ]

    @retry_builder(tries=3, delay=1, backoff=2)
    def contents(self, urls: list[str]) -> list[WebContent]:
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        response = self.exa.get_contents(
            urls=urls,
            urls=list(urls),
            text=True,
            livecrawl="preferred",
        )

@@ -1,4 +1,5 @@
import json
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor

import requests
@@ -55,7 +56,7 @@ class SerperClient(WebSearchProvider):
            for result in organic_results
        ]

    def contents(self, urls: list[str]) -> list[WebContent]:
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        if not urls:
            return []

@@ -78,7 +78,7 @@ def web_search(
    def _search(search_query: str) -> list[WebSearchResult]:
        search_results: list[WebSearchResult] = []
        try:
            search_results = provider.search(search_query)
            search_results = list(provider.search(search_query))
        except Exception as e:
            logger.error(f"Error performing search: {e}")
        return search_results

@@ -1,5 +1,6 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
from enum import Enum

@@ -31,9 +32,9 @@ class WebContent(BaseModel):

class WebSearchProvider(ABC):
    @abstractmethod
    def search(self, query: str) -> list[WebSearchResult]:
    def search(self, query: str) -> Sequence[WebSearchResult]:
        pass

    @abstractmethod
    def contents(self, urls: list[str]) -> list[WebContent]:
    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        pass

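Widening the interface to Sequence lets implementations return tuples or other immutable sequences while callers normalize with list(...), as in web_search above. A minimal conforming provider, sketched under the assumption that the ABC and models live in the web_search models module referenced elsewhere in this diff (hypothetical stub, not part of the change):

    from collections.abc import Sequence

    from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
    from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchProvider
    from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult


    class NullSearchProvider(WebSearchProvider):
        """Stub provider that satisfies the ABC but returns no results."""

        def search(self, query: str) -> Sequence[WebSearchResult]:
            return ()  # any Sequence is acceptable under the new signature

        def contents(self, urls: Sequence[str]) -> list[WebContent]:
            return []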
@@ -4,6 +4,8 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
    WebSearchResult,
)
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
@@ -75,3 +77,23 @@ def dummy_inference_section_from_internet_search_result(
        chunks=[],
        combined_content="",
    )


def llm_doc_from_web_content(web_content: WebContent) -> LlmDoc:
    """Create an LlmDoc from WebContent with the INTERNET_SEARCH_DOC_ prefix"""
    return LlmDoc(
        # TODO: Is this what we want to do for document_id? We're kind of overloading it since it
        # should ideally correspond to a document in the database. But I guess if you're calling this
        # function you know it won't be in the database.
        document_id="INTERNET_SEARCH_DOC_" + web_content.link,
        content=truncate_search_result_content(web_content.full_content),
        blurb=web_content.link,
        semantic_identifier=web_content.link,
        source_type=DocumentSource.WEB,
        metadata={},
        link=web_content.link,
        document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
        updated_at=web_content.published_date,
        source_links={},
        match_highlights=[],
    )
155
backend/onyx/auth/oauth_token_manager.py
Normal file
@@ -0,0 +1,155 @@
import time
from typing import Any
from urllib.parse import urlencode
from uuid import UUID

import requests
from sqlalchemy.orm import Session

from onyx.db.models import OAuthConfig
from onyx.db.models import OAuthUserToken
from onyx.db.oauth_config import get_user_oauth_token
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.logger import setup_logger


logger = setup_logger()


class OAuthTokenManager:
    """Manages OAuth token retrieval, refresh, and validation"""

    def __init__(self, oauth_config: OAuthConfig, user_id: UUID, db_session: Session):
        self.oauth_config = oauth_config
        self.user_id = user_id
        self.db_session = db_session

    def get_valid_access_token(self) -> str | None:
        """Get valid access token, refreshing if necessary"""
        user_token = get_user_oauth_token(
            self.oauth_config.id, self.user_id, self.db_session
        )

        if not user_token:
            return None

        token_data = user_token.token_data

        # Check if token is expired
        if OAuthTokenManager.is_token_expired(token_data):
            # Try to refresh if we have a refresh token
            if "refresh_token" in token_data:
                try:
                    return self.refresh_token(user_token)
                except Exception as e:
                    logger.warning(f"Failed to refresh token: {e}")
                    return None
            else:
                return None

        return token_data.get("access_token")

    def refresh_token(self, user_token: OAuthUserToken) -> str:
        """Refresh access token using refresh token"""
        token_data = user_token.token_data

        response = requests.post(
            self.oauth_config.token_url,
            data={
                "grant_type": "refresh_token",
                "refresh_token": token_data["refresh_token"],
                "client_id": self.oauth_config.client_id,
                "client_secret": self.oauth_config.client_secret,
            },
            headers={"Accept": "application/json"},
        )
        response.raise_for_status()

        new_token_data = response.json()

        # Calculate expires_at if expires_in is present
        if "expires_in" in new_token_data:
            new_token_data["expires_at"] = (
                int(time.time()) + new_token_data["expires_in"]
            )

        # Preserve refresh_token if not returned (some providers don't return it)
        if "refresh_token" not in new_token_data and "refresh_token" in token_data:
            new_token_data["refresh_token"] = token_data["refresh_token"]

        # Update token in DB
        upsert_user_oauth_token(
            self.oauth_config.id,
            self.user_id,
            new_token_data,
            self.db_session,
        )

        return new_token_data["access_token"]

    @classmethod
    def token_expiration_time(cls, token_data: dict[str, Any]) -> int | None:
        """Get the token expiration time"""
        expires_at = token_data.get("expires_at")
        if not expires_at:
            return None

        return expires_at

    @classmethod
    def is_token_expired(cls, token_data: dict[str, Any]) -> bool:
        """Check if token is expired (with 60 second buffer)"""
        expires_at = cls.token_expiration_time(token_data)
        if not expires_at:
            return False  # No expiration data, assume valid

        # Add 60 second buffer to avoid race conditions
        return int(time.time()) + 60 >= expires_at

    def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
        """Exchange authorization code for access token"""
        response = requests.post(
            self.oauth_config.token_url,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "client_id": self.oauth_config.client_id,
                "client_secret": self.oauth_config.client_secret,
                "redirect_uri": redirect_uri,
            },
            headers={"Accept": "application/json"},
        )
        response.raise_for_status()

        token_data = response.json()

        # Calculate expires_at if expires_in is present
        if "expires_in" in token_data:
            token_data["expires_at"] = int(time.time()) + token_data["expires_in"]

        return token_data

    @staticmethod
    def build_authorization_url(
        oauth_config: OAuthConfig, redirect_uri: str, state: str
    ) -> str:
        """Build OAuth authorization URL"""
        params: dict[str, Any] = {
            "client_id": oauth_config.client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state,
        }

        # Add scopes if configured
        if oauth_config.scopes:
            params["scope"] = " ".join(oauth_config.scopes)

        # Add any additional provider-specific parameters
        if oauth_config.additional_params:
            params.update(oauth_config.additional_params)

        # Check if URL already has query parameters
        separator = "&" if "?" in oauth_config.authorization_url else "?"

        return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
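The 60-second buffer in is_token_expired means a token is treated as expired a full minute before its recorded expires_at, so a refresh is triggered before any in-flight request can race the real deadline. A quick illustration of the arithmetic (standalone sketch reusing the classmethod above; the token values are hypothetical):

    import time

    from onyx.auth.oauth_token_manager import OAuthTokenManager

    # Token granted now with a 1-hour lifetime: expires_at = now + 3600.
    token_data = {"access_token": "abc", "expires_at": int(time.time()) + 3600}
    assert not OAuthTokenManager.is_token_expired(token_data)

    # Only 30 seconds of lifetime left falls inside the 60-second buffer,
    # so the manager already reports the token as expired and will refresh.
    token_data["expires_at"] = int(time.time()) + 30
    assert OAuthTokenManager.is_token_expired(token_data)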
@@ -109,13 +109,11 @@ from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.saml import get_saml_account
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.secrets import extract_hashed_cookie
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -1064,17 +1062,7 @@ async def _check_for_saml_and_jwt(
    user: User | None,
    async_db_session: AsyncSession,
) -> User | None:
    # Check if the user has a session cookie from SAML
    if AUTH_TYPE == AuthType.SAML:
        saved_cookie = extract_hashed_cookie(request)

        if saved_cookie:
            saml_account = await get_saml_account(
                cookie=saved_cookie, async_db_session=async_db_session
            )
            user = saml_account.user if saml_account else None

    # If user is still None, check for JWT in Authorization header
    # If user is None, check for JWT in Authorization header
    if user is None and JWT_PUBLIC_KEY_URL is not None:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):

@@ -1,60 +0,0 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator

from langchain_core.messages import BaseMessage

from onyx.chat.models import ResponsePart
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler


# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
    """
    This class is responsible for postprocessing the LLM response stream.
    In particular, we:
    1. handle the tool call requests
    2. handle citations
    3. pass through answers generated by the LLM
    4. Stop yielding if the client disconnects
    """

    def __init__(
        self,
        tool_handler: ToolResponseHandler | None,
        answer_handler: AnswerResponseHandler | None,
        is_cancelled: Callable[[], bool],
    ):
        self.tool_handler = tool_handler or ToolResponseHandler([])
        self.answer_handler = answer_handler or DummyAnswerResponseHandler()
        self.is_cancelled = is_cancelled

    def handle_llm_response(
        self,
        stream: Iterator[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        all_messages: list[BaseMessage | str] = []
        for message in stream:
            if self.is_cancelled():
                yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                return
            # tool handler doesn't do anything until the full message is received
            # NOTE: still need to run list() to get this to run
            list(self.tool_handler.handle_response_part(message, all_messages))
            yield from self.answer_handler.handle_response_part(message, all_messages)
            all_messages.append(message)

        # potentially give back all info on the selected tool call + its result
        yield from self.tool_handler.handle_response_part(None, all_messages)
        yield from self.answer_handler.handle_response_part(None, all_messages)

    def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
        return self.tool_handler.next_llm_call(llm_call)
@@ -33,9 +33,16 @@ if TYPE_CHECKING:
    from onyx.db.models import Persona


# We need this value to be a constant instead of None to avoid JSON serialization issues
DOCUMENT_CITATION_NUMBER_EMPTY_VALUE = -1


class LlmDoc(BaseModel):
    """This contains the minimal set of information for the LLM portion, including citations"""

    # This is kind of cooked. We're overloading this field as both a "catch all" for identifying
    # an LLM doc as well as a way to connect the LlmDoc to the DB. For internal search, it will
    # be an id for the db but not for web search.
    document_id: str
    content: str
    blurb: str
@@ -46,6 +53,7 @@ class LlmDoc(BaseModel):
    link: str | None
    source_links: dict[int, str] | None
    match_highlights: list[str] | None
    document_citation_number: int | None = DOCUMENT_CITATION_NUMBER_EMPTY_VALUE


# First chunk of info for streaming QA

@@ -8,9 +8,12 @@ from typing import cast
from typing import Protocol
from uuid import UUID

from agents import Model
from agents import ModelSettings
from redis.client import Redis
from sqlalchemy.orm import Session

from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -31,10 +34,9 @@ from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import (
    default_build_system_message_v2,
    default_build_system_message_for_default_assistant_v2,
)
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message_v2
from onyx.chat.turn import fast_chat_turn
from onyx.chat.turn.infra.emitter import get_default_emitter
from onyx.chat.turn.models import ChatTurnDependencies
@@ -77,13 +79,14 @@ from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.feature_flags.factory import get_default_feature_flag_provider
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
from onyx.feature_flags.feature_flags_keys import DISABLE_SIMPLE_AGENT_FRAMEWORK
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_all_chat_files
from onyx.kg.models import KGException
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llm_model_and_settings_for_persona
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
@@ -98,7 +101,6 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
@@ -524,9 +526,15 @@ def stream_chat_message_objects(
        if new_msg_req.current_message_files:
            for fd in new_msg_req.current_message_files:
                uid = fd.get("user_file_id")
                if uid is not None:
                    user_file_id = UUID(uid)
                    user_file_ids.append(user_file_id)
                if not uid:
                    continue
                try:
                    user_file_ids.append(UUID(uid))
                except (TypeError, ValueError, AttributeError):
                    logger.warning(
                        "Skipping invalid user_file_id from current_message_files: %s",
                        uid,
                    )

        # Load in user files into memory and create search tool override kwargs if needed
        # if we have enough tokens, we don't need to use search
@@ -715,9 +723,7 @@ def stream_chat_message_objects(
            answer_style_config=answer_style_config,
            document_pruning_config=document_pruning_config,
        ),
        image_generation_tool_config=ImageGenerationToolConfig(
            additional_headers=litellm_additional_headers,
        ),
        image_generation_tool_config=ImageGenerationToolConfig(),
        custom_tool_config=CustomToolConfig(
            chat_session_id=chat_session_id,
            message_id=user_message.id if user_message else None,
@@ -752,31 +758,25 @@ def stream_chat_message_objects(
            ]
        )
        feature_flag_provider = get_default_feature_flag_provider()
        simple_agent_framework_enabled = (
        simple_agent_framework_disabled = (
            feature_flag_provider.feature_enabled_for_user_tenant(
                flag_key=SIMPLE_AGENT_FRAMEWORK,
                flag_key=DISABLE_SIMPLE_AGENT_FRAMEWORK,
                user=user,
                tenant_id=tenant_id,
            )
            and not new_msg_req.use_agentic_search
            or new_msg_req.use_agentic_search
        )
        prompt_user_message = (
            default_build_user_message_v2(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
            )
            if simple_agent_framework_enabled
            else default_build_user_message(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
            )
        prompt_user_message = default_build_user_message(
            user_query=final_msg.message,
            prompt_config=prompt_config,
            files=latest_query_files,
        )
        mem_callback = make_memories_callback(user, db_session)
        system_message = (
            default_build_system_message_v2(prompt_config, llm.config, mem_callback)
            if simple_agent_framework_enabled
            default_build_system_message_for_default_assistant_v2(
                prompt_config, llm.config, mem_callback, tools
            )
            if not simple_agent_framework_disabled and persona.is_default_persona
            else default_build_system_message(prompt_config, llm.config, mem_callback)
        )
        prompt_builder = AnswerPromptBuilder(
@@ -823,7 +823,13 @@ def stream_chat_message_objects(
            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
            project_instructions=project_instructions,
        )
        if simple_agent_framework_enabled:
        if not simple_agent_framework_disabled:
            llm_model, model_settings = get_llm_model_and_settings_for_persona(
                persona=persona,
                llm_override=(new_msg_req.llm_override or chat_session.llm_override),
                additional_headers=litellm_additional_headers,
                timeout=None,  # Will use default timeout logic
            )
            yield from _fast_message_stream(
                answer,
                tools,
@@ -831,6 +837,9 @@ def stream_chat_message_objects(
                get_redis_client(),
                chat_session_id,
                reserved_message_id,
                prompt_config,
                llm_model,
                model_settings,
            )
        else:
            from onyx.chat.packet_proccessing import process_streamed_packets
@@ -882,41 +891,22 @@ def _fast_message_stream(
    redis_client: Redis,
    chat_session_id: UUID,
    reserved_message_id: int,
    prompt_config: PromptConfig,
    llm_model: Model,
    model_settings: ModelSettings,
) -> Generator[Packet, None, None]:
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    messages = base_messages_to_agent_sdk_msgs(
        answer.graph_inputs.prompt_builder.build()
    )
    from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
        OktaProfileTool,
    )
    from onyx.llm.litellm_singleton import LitellmModel

    image_generation_tool_instance = None
    okta_profile_tool_instance = None
    for tool in tools:
        if isinstance(tool, ImageGenerationTool):
            image_generation_tool_instance = tool
        elif isinstance(tool, OktaProfileTool):
            okta_profile_tool_instance = tool
    converted_message_history = [
        PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
        for message in answer.graph_inputs.prompt_builder.build()
    ]
    emitter = get_default_emitter()
    return fast_chat_turn.fast_chat_turn(
        messages=converted_message_history,
        messages=messages,
        # TODO: Maybe we can use some DI framework here?
        dependencies=ChatTurnDependencies(
            llm_model=LitellmModel(
                model=answer.graph_tooling.primary_llm.config.model_name,
                base_url=answer.graph_tooling.primary_llm.config.api_base,
                api_key=answer.graph_tooling.primary_llm.config.api_key,
            ),
            llm_model=llm_model,
            model_settings=model_settings,
            llm=answer.graph_tooling.primary_llm,
            tools=tools_to_function_tools(tools),
            search_pipeline=answer.graph_tooling.search_tool,
            image_generation_tool=image_generation_tool_instance,
            okta_profile_tool=okta_profile_tool_instance,
            tools=tools,
            db_session=db_session,
            redis_client=redis_client,
            emitter=emitter,
@@ -924,6 +914,8 @@ def _fast_message_stream(
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
            research_type=answer.graph_config.behavior.research_type,
            prompt_config=prompt_config,
            force_use_tool=answer.graph_tooling.force_use_tool,
        )

@@ -21,7 +21,10 @@ from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT_V2
from onyx.prompts.chat_prompts import CUSTOM_INSTRUCTIONS_PROMPT
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LONG_CONVERSATION_REMINDER_PROMPT
from onyx.prompts.chat_prompts import TOOL_PERSISTENCE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_company_awareness
@@ -34,13 +37,26 @@ from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


def default_build_system_message_v2(
# TODO: We could do smoother templating than all these sequential
# function calls
def default_build_system_message_for_default_assistant_v2(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    memories_callback: Callable[[], list[str]] | None = None,
) -> SystemMessage | None:
    system_prompt = prompt_config.system_prompt.strip()
    system_prompt += REQUIRE_CITATION_STATEMENT_V2
    tools: list[Tool] | None = None,
) -> SystemMessage:
    # Check if we should include custom instructions (before date processing)
    custom_instructions = prompt_config.system_prompt.strip()
    clean_custom_instructions = "".join(custom_instructions.split())
    clean_default_system_prompt = "".join(DEFAULT_SYSTEM_PROMPT.split())
    should_include_custom_instructions = (
        clean_custom_instructions
        and clean_custom_instructions != clean_default_system_prompt
    )

    # Start with base prompt
    system_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + LONG_CONVERSATION_REMINDER_PROMPT

    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    if (
@@ -48,20 +64,51 @@ def default_build_system_message_v2(
        and llm_config.model_name.startswith("o")
    ):
        system_prompt = CODE_BLOCK_MARKDOWN + system_prompt

    if should_include_custom_instructions:
        system_prompt += "\n\n## Custom Instructions\n"
        system_prompt += CUSTOM_INSTRUCTIONS_PROMPT
        system_prompt += custom_instructions

    tag_handled_prompt = handle_onyx_date_awareness(
        system_prompt,
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )

    if not tag_handled_prompt:
        return None

    tag_handled_prompt = handle_company_awareness(tag_handled_prompt)

    if memories_callback:
        tag_handled_prompt = handle_memories(tag_handled_prompt, memories_callback)

    # Add Tools section if tools are provided
    if tools:
        tag_handled_prompt += "\n\n# Tools\n"
        tag_handled_prompt += TOOL_PERSISTENCE_PROMPT

        for tool in tools:
            if type(tool).__name__ == "WebSearchTool":
                # Import at runtime to avoid circular dependency
                from onyx.tools.tool_implementations_v2.web import (
                    WEB_SEARCH_LONG_DESCRIPTION,
                    OPEN_URL_LONG_DESCRIPTION,
                )

                # Special handling for WebSearchTool - expand to web_search and open_url
                tag_handled_prompt += "\n## web_search\n"
                tag_handled_prompt += WEB_SEARCH_LONG_DESCRIPTION
                tag_handled_prompt += "\n\n## open_url\n"
                tag_handled_prompt += OPEN_URL_LONG_DESCRIPTION
            else:
                # TODO: ToolV2 should make this much cleaner
                from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools

                if tools_to_function_tools([tool]):
                    tag_handled_prompt += (
                        f"\n## {tools_to_function_tools([tool])[0].name}\n"
                    )
                    tag_handled_prompt += tool.description

    return SystemMessage(content=tag_handled_prompt)


@@ -95,24 +142,6 @@ def default_build_system_message(
    return SystemMessage(content=tag_handled_prompt)


def default_build_user_message_v2(
    user_query: str,
    prompt_config: PromptConfig,
    files: list[InMemoryChatFile] = [],
) -> HumanMessage:
    user_prompt = user_query
    user_prompt = user_prompt.strip()
    tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
    user_msg = HumanMessage(
        content=(
            build_content_with_imgs(tag_handled_prompt, files)
            if files
            else tag_handled_prompt
        )
    )
    return user_msg


def default_build_user_message(
    user_query: str,
    prompt_config: PromptConfig,

@@ -20,7 +20,6 @@ class AnswerResponseHandler(abc.ABC):
    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[ResponsePart, None, None]:
        raise NotImplementedError

@@ -29,7 +28,6 @@ class PassThroughAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[ResponsePart, None, None]:
        content = _message_to_str(response_item)
        yield OnyxAnswerPiece(answer_piece=content)
@@ -39,7 +37,6 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[ResponsePart, None, None]:
        # This is a dummy handler that returns nothing
        yield from []
@@ -49,27 +46,19 @@ class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
        self,
        context_docs: list[LlmDoc],
        final_doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_id_to_rank_map: DocumentIdOrderMapping,
        doc_id_to_rank_map: DocumentIdOrderMapping,
    ):
        self.context_docs = context_docs
        self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
        self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
        self.citation_processor = CitationProcessor(
            context_docs=self.context_docs,
            final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
            display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
            doc_id_to_rank_map=doc_id_to_rank_map,
        )
        self.processed_text = ""
        self.citations: list[CitationInfo] = []

        # TODO remove this after citation issue is resolved
        logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")

    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[ResponsePart, None, None]:
        if response_item is None:
            return

@@ -50,13 +50,11 @@ class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs # list of docs in the order the LLM sees
|
||||
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
||||
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.stop_stream = stop_stream
|
||||
|
||||
@@ -69,11 +67,14 @@ class CitationProcessor:
|
||||
self.non_citation_count = 0
|
||||
|
||||
# '[', '[[', '[1', '[[1', '[1,', '[1, ', '[1,2', '[1, 2,', etc.
|
||||
self.possible_citation_pattern = re.compile(r"(\[+(?:\d+,? ?)*$)")
|
||||
# Also matches unicode bracket variants: 【, [
|
||||
self.possible_citation_pattern = re.compile(r"([\[【[]+(?:\d+,? ?)*$)")
|
||||
|
||||
# group 1: '[[1]]', [[2]], etc.
|
||||
# group 2: '[1]', '[1, 2]', '[1,2,16]', etc.
|
||||
self.citation_pattern = re.compile(r"(\[\[\d+\]\])|(\[\d+(?:, ?\d+)*\])")
|
||||
# group 1: '[[1]]', [[2]], etc. (also matches 【【1】】, [[1]], 【1】, [1])
|
||||
# group 2: '[1]', '[1, 2]', '[1,2,16]', etc. (also matches unicode variants)
|
||||
self.citation_pattern = re.compile(
|
||||
r"([\[【[]{2}\d+[\]】]]{2})|([\[【[]\d+(?:, ?\d+)*[\]】]])"
|
||||
)
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
@@ -149,15 +150,20 @@ class CitationProcessor:
|
||||
def process_citation(self, match: re.Match) -> tuple[str, list[CitationInfo]]:
|
||||
"""
|
||||
Process a single citation match and return the citation string and the
|
||||
citation info. The match string can look like '[1]', '[1, 13, 6], '[[4]]', etc.
|
||||
citation info. The match string can look like '[1]', '[1, 13, 6], '[[4]]',
|
||||
'【1】', '【【4】】', '[1]', etc.
|
||||
"""
|
||||
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', etc.
|
||||
formatted = match.lastindex == 1 # True means already in the form '[[1]]'
|
||||
citation_str: str = (
|
||||
match.group()
|
||||
) # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】', etc.
|
||||
formatted = (
|
||||
match.lastindex == 1
|
||||
) # True means already in the form '[[1]]' or '【【1】】'
|
||||
|
||||
final_processed_str = ""
|
||||
final_citation_info: list[CitationInfo] = []
|
||||
|
||||
# process the citation_str
|
||||
# process the citation_str - regex ensures matched brackets, so we can simply slice
|
||||
citation_content = citation_str[2:-2] if formatted else citation_str[1:-1]
|
||||
for num in (int(num) for num in citation_content.split(",")):
|
||||
# keep invalid citations as is
|
||||
@@ -169,13 +175,13 @@ class CitationProcessor:
|
||||
# should always be in the display_doc_order_dict. But check anyways
|
||||
context_llm_doc = self.context_docs[num - 1]
|
||||
llm_docid = context_llm_doc.document_id
|
||||
if llm_docid not in self.display_order_mapping:
|
||||
if llm_docid not in self.order_mapping:
|
||||
logger.warning(
|
||||
f"Doc {llm_docid} not in display_doc_order_dict. "
|
||||
f"Doc {llm_docid} not in doc_order_dict. "
|
||||
"Used LLM citation number instead."
|
||||
)
|
||||
displayed_citation_num = self.display_order_mapping.get(
|
||||
llm_docid, self.final_order_mapping[llm_docid]
|
||||
displayed_citation_num = self.order_mapping.get(
|
||||
llm_docid, self.order_mapping[llm_docid]
|
||||
)
|
||||
|
||||
# skip citations of the same work if cited recently
|
||||
@@ -223,13 +229,17 @@ class CitationProcessorGraph:

        # '[', '[[', '[1', '[[1', '[1,', '[1, ', '[1,2', '[1, 2,', etc.
        # Also supports '[D1', '[D1, D3' type patterns
        self.possible_citation_pattern = re.compile(r"(\[+(?:(?:\d+|D\d+),? ?)*$)")
        # Also supports unicode bracket variants: 【, [
        self.possible_citation_pattern = re.compile(
            r"([\[【[]+(?:(?:\d+|D\d+),? ?)*$)"
        )

        # group 1: '[[1]]', [[2]], etc.
        # group 2: '[1]', '[1, 2]', '[1,2,16]', etc.
        # Also supports '[D1]', '[D1, D3]', '[[D1]]' type patterns
        # Also supports unicode bracket variants
        self.citation_pattern = re.compile(
            r"(\[\[(?:\d+|D\d+)\]\])|(\[(?:\d+|D\d+)(?:, ?(?:\d+|D\d+))*\])"
            r"([\[【[]{2}(?:\d+|D\d+)[\]】]]{2})|([\[【[](?:\d+|D\d+)(?:, ?(?:\d+|D\d+))*[\]】]])"
        )

    def process_token(
@@ -309,15 +319,20 @@ class CitationProcessorGraph:
    def process_citation(self, match: re.Match) -> tuple[str, list[CitationInfo]]:
        """
        Process a single citation match and return the citation string and the
        citation info. The match string can look like '[1]', '[1, 13, 6], '[[4]]', etc.
        citation info. The match string can look like '[1]', '[1, 13, 6], '[[4]]',
        '【1】', '【【4】】', '[1]', '[D1]', etc.
        """
        citation_str: str = match.group()  # e.g., '[1]', '[1, 2, 3]', '[[1]]', etc.
        formatted = match.lastindex == 1  # True means already in the form '[[1]]'
        citation_str: str = (
            match.group()
        )  # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】', etc.
        formatted = (
            match.lastindex == 1
        )  # True means already in the form '[[1]]' or '【【1】】'

        final_processed_str = ""
        final_citation_info: list[CitationInfo] = []

        # process the citation_str
        # process the citation_str - regex ensures matched brackets, so we can simply slice
        citation_content = citation_str[2:-2] if formatted else citation_str[1:-1]
        for num in (int(num) for num in citation_content.split(",")):
            # keep invalid citations as is
@@ -21,3 +21,11 @@ def map_document_id_order(
            current += 1

    return DocumentIdOrderMapping(order_mapping=order_mapping)


def map_document_id_order_v2(fetched_docs: list[LlmDoc]) -> DocumentIdOrderMapping:
    order_mapping = {}
    for doc in fetched_docs:
        if doc.document_id not in order_mapping and doc.document_citation_number:
            order_mapping[doc.document_id] = doc.document_citation_number
    return DocumentIdOrderMapping(order_mapping=order_mapping)
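A rough usage sketch of the v2 mapper; the LlmDoc stand-ins below are hypothetical and carry only the two fields the function reads:

from types import SimpleNamespace

docs = [
    SimpleNamespace(document_id="doc-a", document_citation_number=1),
    SimpleNamespace(document_id="doc-b", document_citation_number=2),
    SimpleNamespace(document_id="doc-a", document_citation_number=3),  # repeat id, ignored
]
mapping = map_document_id_order_v2(docs)  # type: ignore[arg-type]
print(mapping.order_mapping)  # {'doc-a': 1, 'doc-b': 2}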
@@ -109,7 +109,6 @@ class ToolResponseHandler:
    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[ResponsePart, None, None]:
        if response_item is None:
            yield from self._handle_tool_call()
backend/onyx/chat/turn/context_handler/__init__.py (new file, 0 lines)
backend/onyx/chat/turn/context_handler/citation.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Citation context handler for assigning sequential citation numbers to documents."""

import json
from collections.abc import Sequence

from pydantic import BaseModel
from pydantic import ValidationError

from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import FunctionCallOutputMessage
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.chat.models import LlmDoc
from onyx.chat.turn.models import ChatTurnContext


class CitationAssignmentResult(BaseModel):
    updated_messages: list[AgentSDKMessage]
    num_docs_cited: int
    num_tool_calls_cited: int
    new_llm_docs: list[LlmDoc]


def assign_citation_numbers_recent_tool_calls(
    agent_turn_messages: Sequence[AgentSDKMessage],
    ctx: ChatTurnContext,
) -> CitationAssignmentResult:
    updated_messages: list[AgentSDKMessage] = []
    docs_fetched_so_far = ctx.documents_processed_by_citation_context_handler
    tool_calls_cited_so_far = ctx.tool_calls_processed_by_citation_context_handler
    num_tool_calls_cited = 0
    num_docs_cited = 0
    curr_tool_call_idx = 0
    new_llm_docs: list[LlmDoc] = []

    for message in agent_turn_messages:
        new_message: AgentSDKMessage | None = None
        if message.get("type") == "function_call_output":
            if curr_tool_call_idx >= tool_calls_cited_so_far:
                # Type narrow to FunctionCallOutputMessage after checking the 'type' field
                func_call_output_msg: FunctionCallOutputMessage = message  # type: ignore[assignment]
                content = func_call_output_msg["output"]
                try:
                    raw_list = json.loads(content)
                    llm_docs = [LlmDoc(**doc) for doc in raw_list]
                except (json.JSONDecodeError, TypeError, ValidationError):
                    llm_docs = []

                if llm_docs:
                    updated_citation_number = False
                    for doc in llm_docs:
                        if (
                            doc.document_citation_number
                            == DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
                        ):
                            num_docs_cited += 1  # add 1 first so it's 1-indexed
                            updated_citation_number = True
                            doc.document_citation_number = (
                                docs_fetched_so_far + num_docs_cited
                            )
                    if updated_citation_number:
                        # Create updated function call output message
                        updated_output_message: FunctionCallOutputMessage = {
                            "type": "function_call_output",
                            "call_id": func_call_output_msg["call_id"],
                            "output": json.dumps(
                                [doc.model_dump(mode="json") for doc in llm_docs]
                            ),
                        }
                        new_message = updated_output_message
                num_tool_calls_cited += 1
                new_llm_docs.extend(llm_docs)
            # Increment counter for ALL function_call_output messages, not just processed ones
            curr_tool_call_idx += 1

        updated_messages.append(new_message or message)

    return CitationAssignmentResult(
        updated_messages=updated_messages,
        num_docs_cited=num_docs_cited,
        num_tool_calls_cited=num_tool_calls_cited,
        new_llm_docs=new_llm_docs,
    )
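In effect, each pass numbers only the tool-call outputs it has not yet seen, continuing from the running document count stored on the context. A toy trace of the counter arithmetic (hypothetical values):

# Two docs were numbered on an earlier pass (docs_fetched_so_far = 2);
# a new tool call returns two unnumbered docs -> they become 3 and 4.
docs_fetched_so_far = 2
num_docs_cited = 0
for _unnumbered_doc in range(2):
    num_docs_cited += 1  # add 1 first so it's 1-indexed
    print("assigned:", docs_fetched_so_far + num_docs_cited)  # 3, then 4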
backend/onyx/chat/turn/context_handler/task_prompt.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""Task prompt context handler for updating task prompts in agent messages."""

from collections.abc import Sequence

from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import InputTextContent
from onyx.agents.agent_sdk.message_types import UserMessage
from onyx.chat.models import PromptConfig
from onyx.prompts.prompt_utils import build_task_prompt_reminders_v2


def update_task_prompt(
    current_user_message: UserMessage,
    agent_turn_messages: Sequence[AgentSDKMessage],
    prompt_config: PromptConfig,
    should_cite_documents: bool,
) -> list[AgentSDKMessage]:
    user_query = _extract_user_query(current_user_message)
    new_task_prompt_text = build_task_prompt_reminders_v2(
        user_query,
        prompt_config,
        use_language_hint=False,
        should_cite=should_cite_documents,
    )
    last_user_idx = max(
        (i for i, m in enumerate(agent_turn_messages) if m.get("role") == "user"),
        default=-1,
    )

    # Filter out last user message and add new task prompt as user message
    filtered_messages: list[AgentSDKMessage] = [
        m for i, m in enumerate(agent_turn_messages) if i != last_user_idx
    ]

    text_content: InputTextContent = {
        "type": "input_text",
        "text": new_task_prompt_text,
    }
    new_user_message: UserMessage = {"role": "user", "content": [text_content]}

    return filtered_messages + [new_user_message]


def _extract_user_query(current_user_message: UserMessage) -> str:
    pass

    first_content = current_user_message["content"][0]
    # User messages contain InputTextContent or ImageContent
    # Only InputTextContent has "text" field, ImageContent has "image_url"
    if first_content["type"] == "input_text":
        # Type narrow - we know it's InputTextContent based on the type check
        text_content: InputTextContent = first_content  # type: ignore[assignment]
        return text_content["text"]
    # If it's an image content, return empty string or handle appropriately
    return ""
@@ -1,28 +1,39 @@
from dataclasses import replace
from typing import cast
from typing import TYPE_CHECKING
from uuid import UUID

from agents import Agent
from agents import ModelSettings
from agents import RawResponsesStreamEvent
from agents import StopAtTools
from agents import RunResultStreaming
from agents import ToolCallItem
from agents.tracing import trace

from onyx.agents.agent_sdk.message_types import AgentSDKMessage
from onyx.agents.agent_sdk.message_types import UserMessage
from onyx.agents.agent_sdk.monkey_patches import (
    monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search,
)
from onyx.agents.agent_sdk.sync_agent_stream_adapter import SyncAgentStream
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import PromptConfig
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import map_document_id_order_v2
from onyx.chat.turn.context_handler.citation import (
    assign_citation_numbers_recent_tool_calls,
)
from onyx.chat.turn.context_handler.task_prompt import update_task_prompt
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
from onyx.chat.turn.infra.session_sink import save_iteration
from onyx.chat.turn.infra.sync_agent_stream_adapter import SyncAgentStream
from onyx.chat.turn.models import AgentToolType
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.context.search.models import InferenceSection
from onyx.chat.turn.save_turn import extract_final_answer_from_packets
from onyx.chat.turn.save_turn import save_turn
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
@@ -30,18 +41,117 @@ from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import ForceUseTool

if TYPE_CHECKING:
    from litellm import ResponseFunctionToolCall


# TODO -- this can be refactored out and played with in evals + normal demo
def _run_agent_loop(
    messages: list[AgentSDKMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    ctx: ChatTurnContext,
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
) -> None:
    monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search()
    # Split messages into three parts for clear tracking
    # TODO: Think about terminal tool calls like image gen
    # in multi turn conversations
    chat_history = messages[:-1]
    current_user_message = messages[-1]
    if (
        not isinstance(current_user_message, dict)
        or current_user_message.get("role") != "user"
    ):
        raise ValueError("Last message must be a user message")
    current_user_message_typed: UserMessage = current_user_message  # type: ignore
    agent_turn_messages: list[AgentSDKMessage] = []
    last_call_is_final = False
    first_iteration = True

    while not last_call_is_final:
        current_messages = chat_history + [current_user_message] + agent_turn_messages
        if not dependencies.tools:
            tool_choice = None
        else:
            tool_choice = (
                force_use_tool_to_function_tool_names(
                    force_use_tool, dependencies.tools
                )
                if first_iteration and force_use_tool
                else None
            ) or "auto"
        model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)
        agent = Agent(
            name="Assistant",
            model=dependencies.llm_model,
            tools=cast(
                list[AgentToolType], tools_to_function_tools(dependencies.tools)
            ),
            model_settings=model_settings,
            tool_use_behavior="stop_on_first_tool",
        )
        agent_stream: SyncAgentStream = SyncAgentStream(
            agent=agent,
            input=current_messages,
            context=ctx,
        )
        streamed, tool_call_events = _process_stream(
            agent_stream, chat_session_id, dependencies, ctx
        )

        all_messages_after_stream = streamed.to_input_list()
        # The new messages are everything after chat_history + current_user_message
        previous_message_count = len(chat_history) + 1
        agent_turn_messages = [
            cast(AgentSDKMessage, msg)
            for msg in all_messages_after_stream[previous_message_count:]
        ]

        agent_turn_messages = list(
            update_task_prompt(
                current_user_message_typed,
                agent_turn_messages,
                prompt_config,
                ctx.should_cite_documents,
            )
        )
        citation_result = assign_citation_numbers_recent_tool_calls(
            agent_turn_messages, ctx
        )
        agent_turn_messages = list(citation_result.updated_messages)
        ctx.ordered_fetched_documents.extend(citation_result.new_llm_docs)
        ctx.documents_processed_by_citation_context_handler += (
            citation_result.num_docs_cited
        )
        ctx.tool_calls_processed_by_citation_context_handler += (
            citation_result.num_tool_calls_cited
        )

        # TODO: Make this configurable on OnyxAgent level
        stopping_tools = ["image_generation"]
        if len(tool_call_events) == 0 or any(
            tool.name in stopping_tools for tool in tool_call_events
        ):
            last_call_is_final = True
        first_iteration = False


def _fast_chat_turn_core(
    messages: list[dict],
    messages: list[AgentSDKMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    # Dependency injectable arguments for testing
    starter_global_iteration_responses: list[IterationAnswer] | None = None,
    starter_cited_documents: list[InferenceSection] | None = None,
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
    # Dependency injectable argument for testing
    starter_context: ChatTurnContext | None = None,
) -> None:
    """Core fast chat turn logic that allows overriding global_iteration_responses for testing.
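The tool_choice expression in the loop above collapses to "auto" whenever forcing is off or nothing is forced. A minimal sketch of that fallback pattern (hypothetical stand-in, not the real adapter call):

def pick_tool_choice(first_iteration: bool, forced: list[str] | None) -> object:
    return (forced if first_iteration and forced else None) or "auto"

print(pick_tool_choice(True, ["search"]))   # ['search']
print(pick_tool_choice(False, ["search"]))  # auto
print(pick_tool_choice(True, None))         # auto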
@@ -58,37 +168,86 @@ def _fast_chat_turn_core(
        chat_session_id,
        dependencies.redis_client,
    )
    ctx = ChatTurnContext(
    ctx = starter_context or ChatTurnContext(
        run_dependencies=dependencies,
        aggregated_context=AggregatedDRContext(
            context="context",
            cited_documents=starter_cited_documents or [],
            is_internet_marker_dict={},
            global_iteration_responses=starter_global_iteration_responses or [],
        ),
        iteration_instructions=[],
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
    )
    agent = Agent(
        name="Assistant",
        model=dependencies.llm_model,
        tools=cast(list[AgentToolType], dependencies.tools),
        model_settings=ModelSettings(
            temperature=dependencies.llm.config.temperature,
            include_usage=True,
        ),
        tool_use_behavior=StopAtTools(stop_at_tool_names=[image_generation_tool.name]),
    with trace("fast_chat_turn"):
        _run_agent_loop(
            messages=messages,
            dependencies=dependencies,
            chat_session_id=chat_session_id,
            ctx=ctx,
            prompt_config=prompt_config,
            force_use_tool=force_use_tool,
        )
        _emit_citations_for_final_answer(
            dependencies=dependencies,
            ctx=ctx,
        )
    # By default, the agent can only take 10 turns. For our use case, it should be higher.
    max_turns = 25
    agent_stream: SyncAgentStream = SyncAgentStream(
        agent=agent,
        input=messages,
        context=ctx,
        max_turns=max_turns,
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )
    save_turn(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        model_name=dependencies.llm.config.model_name,
        model_provider=dependencies.llm.config.model_provider,
        iteration_instructions=ctx.iteration_instructions,
        global_iteration_responses=ctx.global_iteration_responses,
        final_answer=final_answer,
        unordered_fetched_inference_sections=ctx.unordered_fetched_inference_sections,
        ordered_fetched_documents=ctx.ordered_fetched_documents,
    )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )


@unified_event_stream
def fast_chat_turn(
    messages: list[AgentSDKMessage],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    prompt_config: PromptConfig,
    force_use_tool: ForceUseTool | None = None,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters."""
    _fast_chat_turn_core(
        messages,
        dependencies,
        chat_session_id,
        message_id,
        research_type,
        prompt_config,
        force_use_tool=force_use_tool,
    )


def _process_stream(
    agent_stream: SyncAgentStream,
    chat_session_id: UUID,
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> tuple[RunResultStreaming, list["ResponseFunctionToolCall"]]:
    from litellm import ResponseFunctionToolCall

    mapping = map_document_id_order_v2(ctx.ordered_fetched_documents)
    if ctx.ordered_fetched_documents:
        processor = CitationProcessor(
            context_docs=ctx.ordered_fetched_documents,
            doc_id_to_rank_map=mapping,
            stop_stream=None,
        )
    else:
        processor = None
    tool_call_events: list[ResponseFunctionToolCall] = []
    for ev in agent_stream:
        connected = is_connected(
            chat_session_id,
@@ -98,58 +257,14 @@ def _fast_chat_turn_core(
            _emit_clean_up_packets(dependencies, ctx)
            agent_stream.cancel()
            break
        obj = _default_packet_translation(ev, ctx)
        obj = _default_packet_translation(ev, ctx, processor)
        if obj:
            dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )

    all_cited_documents = []
    if ctx.aggregated_context.global_iteration_responses:
        context_docs = _gather_context_docs_from_iteration_answers(
            ctx.aggregated_context.global_iteration_responses
        )
        all_cited_documents = context_docs
        if context_docs and final_answer:
            _process_citations_for_final_answer(
                final_answer=final_answer,
                context_docs=context_docs,
                dependencies=dependencies,
                ctx=ctx,
            )

    save_iteration(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        ctx=ctx,
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
    )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )


@unified_event_stream
def fast_chat_turn(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters."""
    _fast_chat_turn_core(
        messages,
        dependencies,
        chat_session_id,
        message_id,
        research_type,
        starter_global_iteration_responses=None,
    )
        if isinstance(getattr(ev, "item", None), ToolCallItem):
            tool_call_events.append(cast(ResponseFunctionToolCall, ev.item.raw_item))
    if agent_stream.streamed is None:
        raise ValueError("agent_stream.streamed is None")
    return agent_stream.streamed, tool_call_events


# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
@@ -173,85 +288,46 @@ def _emit_clean_up_packets(
    )


def _gather_context_docs_from_iteration_answers(
    iteration_answers: list[IterationAnswer],
) -> list[InferenceSection]:
    """Gather cited documents from iteration answers for citation processing."""
    context_docs: list[InferenceSection] = []

    for iteration_answer in iteration_answers:
        # Extract cited documents from this iteration
        for inference_section in iteration_answer.cited_documents.values():
            # Avoid duplicates by checking document_id
            if not any(
                doc.center_chunk.document_id
                == inference_section.center_chunk.document_id
                for doc in context_docs
            ):
                context_docs.append(inference_section)

    return context_docs


def _process_citations_for_final_answer(
    final_answer: str,
    context_docs: list[InferenceSection],
def _emit_citations_for_final_answer(
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> None:
    index = ctx.current_run_step + 1
    """Process citations in the final answer and emit citation events."""
    from onyx.chat.stream_processing.utils import DocumentIdOrderMapping

    # Convert InferenceSection objects to LlmDoc objects for citation processing
    llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]

    # Create document ID to rank mappings (simple 1-based indexing)
    final_doc_id_to_rank_map = DocumentIdOrderMapping(
        order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
    )
    display_doc_id_to_rank_map = final_doc_id_to_rank_map  # Same mapping for display

    # Initialize citation processor
    citation_processor = CitationProcessor(
        context_docs=llm_docs,
        final_doc_id_to_rank_map=final_doc_id_to_rank_map,
        display_doc_id_to_rank_map=display_doc_id_to_rank_map,
    )

    # Process the final answer through citation processor
    collected_citations: list = []
    for response_part in citation_processor.process_token(final_answer):
        if hasattr(response_part, "citation_num"):  # It's a CitationInfo
            collected_citations.append(response_part)

    # Emit citation events if we found any citations
    if collected_citations:
    if ctx.citations:
        dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
        dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CitationDelta(citations=collected_citations),  # type: ignore[arg-type]
                obj=CitationDelta(citations=ctx.citations),
            )
        )
        dependencies.emitter.emit(Packet(ind=index, obj=SectionEnd(type="section_end")))
    ctx.current_run_step = index


def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
def _default_packet_translation(
    ev: object, ctx: ChatTurnContext, processor: CitationProcessor | None
) -> PacketObj | None:
    if isinstance(ev, RawResponsesStreamEvent):
        # TODO: might need some variation here for different types of models
        # OpenAI packet translator
        obj: PacketObj | None = None
        if ev.data.type == "response.content_part.added":
            retrieved_search_docs = convert_inference_sections_to_search_docs(
                ctx.aggregated_context.cited_documents
            retrieved_search_docs = saved_search_docs_from_llm_docs(
                ctx.ordered_fetched_documents
            )
            obj = MessageStart(
                type="message_start", content="", final_documents=retrieved_search_docs
            )
        elif ev.data.type == "response.output_text.delta":
            obj = MessageDelta(type="message_delta", content=ev.data.delta)
        elif ev.data.type == "response.output_text.delta" and len(ev.data.delta) > 0:
            if processor:
                final_answer_piece = ""
                for response_part in processor.process_token(ev.data.delta):
                    if isinstance(response_part, CitationInfo):
                        ctx.citations.append(response_part)
                    else:
                        final_answer_piece += response_part.answer_piece or ""
                obj = MessageDelta(type="message_delta", content=final_answer_piece)
            else:
                obj = MessageDelta(type="message_delta", content=ev.data.delta)
        elif ev.data.type == "response.content_part.done":
            obj = SectionEnd(type="section_end")
    return obj
@@ -11,22 +11,20 @@ from agents import HostedMCPTool
from agents import ImageGenerationTool as AgentsImageGenerationTool
from agents import LocalShellTool
from agents import Model
from agents import ModelSettings
from agents import WebSearchTool
from redis.client import Redis
from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.models import LlmDoc
from onyx.chat.turn.infra.emitter import Emitter
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
    OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool import Tool

# Type alias for all tool types accepted by the Agent
AgentToolType = (
@@ -44,14 +42,13 @@ AgentToolType = (
@dataclass
class ChatTurnDependencies:
    llm_model: Model
    model_settings: ModelSettings
    # TODO we can delete this field (combine them)
    llm: LLM
    db_session: Session
    tools: Sequence[FunctionTool]
    tools: Sequence[Tool]
    redis_client: Redis
    emitter: Emitter
    search_pipeline: SearchTool | None = None
    image_generation_tool: ImageGenerationTool | None = None
    okta_profile_tool: OktaProfileTool | None = None


@dataclass
@@ -62,9 +59,18 @@ class ChatTurnContext:
    message_id: int
    research_type: ResearchType
    run_dependencies: ChatTurnDependencies
    aggregated_context: AggregatedDRContext
    current_run_step: int = 0
    iteration_instructions: list[IterationInstructions] = dataclasses.field(
        default_factory=list
    )
    web_fetch_results: list[dict] = dataclasses.field(default_factory=list)
    global_iteration_responses: list[IterationAnswer] = dataclasses.field(
        default_factory=list
    )
    should_cite_documents: bool = False
    documents_processed_by_citation_context_handler: int = 0
    tool_calls_processed_by_citation_context_handler: int = 0
    unordered_fetched_inference_sections: list[InferenceSection] = dataclasses.field(
        default_factory=list
    )
    ordered_fetched_documents: list[LlmDoc] = dataclasses.field(default_factory=list)
    citations: list[CitationInfo] = dataclasses.field(default_factory=list)
@@ -9,11 +9,14 @@ from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
    GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
@@ -26,43 +29,59 @@ from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet


def save_iteration(
def save_turn(
    db_session: Session,
    message_id: int,
    chat_session_id: UUID,
    research_type: ResearchType,
    ctx: ChatTurnContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
    unordered_fetched_inference_sections: list[InferenceSection],
    ordered_fetched_documents: list[LlmDoc],
    iteration_instructions: list[IterationInstructions],
    global_iteration_responses: list[IterationAnswer],
    # TODO: figure out better way to pass these dependencies
    model_name: str,
    model_provider: str,
) -> None:
    # first, insert the search_docs
    is_internet_marker_dict: dict[str, bool] = {}
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            inference_section=doc,
            is_internet=doc.center_chunk.source_type == DocumentSource.WEB,
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
        for doc in unordered_fetched_inference_sections
    ]

    # then, map_search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
    # Create mapping: citation_number -> document_id
    citation_to_doc_id = {
        doc.document_citation_number: doc.document_id
        for doc in ordered_fetched_documents
        if doc.document_citation_number is not None
    }

    # Create mapping: document_id -> search_doc.id
    doc_id_to_search_doc_id = {doc.document_id: doc.id for doc in search_docs}

    # Chain the lookups: cited_doc_nr -> document_id -> search_doc.id
    citation_dict = {
        cited_doc_nr: doc_id_to_search_doc_id[citation_to_doc_id[cited_doc_nr]]
        for cited_doc_nr in cited_doc_nrs
        if cited_doc_nr in citation_to_doc_id
        and citation_to_doc_id[cited_doc_nr] in doc_id_to_search_doc_id
    }
    llm_tokenizer = get_tokenizer(
        model_name=ctx.run_dependencies.llm.config.model_name,
        provider_type=ctx.run_dependencies.llm.config.model_provider,
        model_name=model_name,
        provider_type=model_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))
    # Update the chat message and its parent message in database
@@ -83,7 +102,7 @@ def save_iteration(

    # TODO: I don't think this is the ideal schema for all use cases
    # find a better schema to store tool and reasoning calls
    for iteration_preparation in ctx.iteration_instructions:
    for iteration_preparation in iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
@@ -92,7 +111,7 @@ def save_iteration(
        )
        db_session.add(research_agent_iteration_step)

    for iteration_answer in ctx.aggregated_context.global_iteration_responses:
    for iteration_answer in global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
@@ -126,7 +126,7 @@ OAUTH_CLIENT_SECRET = (
    os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
    or ""
)
# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations
# OpenID Connect configuration URL for OIDC integrations
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or ""

# Applicable for OIDC Auth, allows you to override the scopes that
@@ -687,11 +687,6 @@ MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Tool Configs
#####
OKTA_PROFILE_TOOL_ENABLED = (
    os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true"
)
# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile.
OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or ""


#####
@@ -21,7 +21,6 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"

# Cookies
FASTAPI_USERS_AUTH_COOKIE_NAME = (
@@ -744,7 +744,10 @@ class ConfluenceConnector(

    def validate_connector_settings(self) -> None:
        try:
            spaces = self.low_timeout_confluence_client.get_all_spaces(limit=1)
            spaces_iter = self.low_timeout_confluence_client.retrieve_confluence_spaces(
                limit=1,
            )
            first_space = next(spaces_iter, None)
        except HTTPError as e:
            status_code = e.response.status_code if e.response else None
            if status_code == 401:
@@ -763,6 +766,12 @@ class ConfluenceConnector(
                f"Unexpected error while validating Confluence settings: {e}"
            )

        if not first_space:
            raise ConnectorValidationError(
                "No Confluence spaces found. Either your credentials lack permissions, or "
                "there truly are no spaces in this Confluence instance."
            )

        if self.space:
            try:
                self.low_timeout_confluence_client.get_space(self.space)
@@ -771,12 +780,6 @@ class ConfluenceConnector(
                    "Invalid Confluence space key provided"
                ) from e

        if not spaces or not spaces.get("results"):
            raise ConnectorValidationError(
                "No Confluence spaces found. Either your credentials lack permissions, or "
                "there truly are no spaces in this Confluence instance."
            )


if __name__ == "__main__":
    import os
@@ -46,7 +46,6 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()
@@ -63,6 +62,9 @@ _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000

_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"


class ConfluenceRateLimitError(Exception):
    pass
@@ -213,6 +215,97 @@ class OnyxConfluence:
        ]
        return oauth2_dict

    def _build_spaces_url(
        self,
        is_v2: bool,
        base_url: str,
        limit: int,
        space_keys: list[str] | None,
        start: int | None = None,
    ) -> str:
        """Build URL for Confluence spaces API with query parameters."""
        key_param = "keys" if is_v2 else "spaceKey"

        params = [f"limit={limit}"]
        if space_keys:
            params.append(f"{key_param}={','.join(space_keys)}")
        if start is not None and not is_v2:
            params.append(f"start={start}")

        return f"{base_url}?{'&'.join(params)}"

    def _paginate_spaces_for_endpoint(
        self,
        is_v2: bool,
        base_url: str,
        limit: int,
        space_keys: list[str] | None,
    ) -> Iterator[dict[str, Any]]:
        """Internal helper to paginate through spaces for a specific API endpoint."""
        start = 0
        url = self._build_spaces_url(
            is_v2, base_url, limit, space_keys, start if not is_v2 else None
        )

        while url:
            response = self.get(url, advanced_mode=True)
            response.raise_for_status()
            data = response.json()

            results = data.get("results", [])
            if not results:
                return

            yield from results

            if is_v2:
                url = data.get("_links", {}).get("next", "")
            else:
                if len(results) < limit:
                    return
                start += len(results)
                url = self._build_spaces_url(is_v2, base_url, limit, space_keys, start)

    def retrieve_confluence_spaces(
        self,
        space_keys: list[str] | None = None,
        limit: int = 50,
    ) -> Iterator[dict[str, str]]:
        """
        Retrieve spaces from Confluence using v2 API (Cloud) or v1 API (Server/fallback).

        Args:
            space_keys: Optional list of space keys to filter by
            limit: Results per page (default 50)

        Yields:
            Space dictionaries with keys: id, key, name, type, status, etc.

        Note:
            For Cloud instances, attempts v2 API first. If v2 returns 404,
            automatically falls back to v1 API for compatibility with older instances.
        """
        # Determine API version once
        use_v2 = self._is_cloud and not self.scoped_token
        base_url = _CONFLUENCE_SPACES_API_V2 if use_v2 else _CONFLUENCE_SPACES_API_V1

        try:
            yield from self._paginate_spaces_for_endpoint(
                use_v2, base_url, limit, space_keys
            )
        except HTTPError as e:
            if e.response.status_code == 404 and use_v2:
                logger.warning(
                    "v2 spaces API returned 404, falling back to v1 API. "
                    "This may indicate an older Confluence Cloud instance."
                )
                # Fallback to v1
                yield from self._paginate_spaces_for_endpoint(
                    False, _CONFLUENCE_SPACES_API_V1, limit, space_keys
                )
            else:
                raise

    def _probe_connection(
        self,
        **kwargs: Any,
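For illustration, the URL builder above yields query strings like the following (a sketch; the base paths come from the module constants):

# v1 uses offset pagination via `start`; v2 follows the `_links.next`
# cursor instead, so `start` is only appended for v1.
# _build_spaces_url(is_v2=False, base_url="rest/api/space", limit=50,
#                   space_keys=["ENG", "OPS"], start=100)
#   -> "rest/api/space?limit=50&spaceKey=ENG,OPS&start=100"
# _build_spaces_url(is_v2=True, base_url="wiki/api/v2/spaces", limit=50,
#                   space_keys=["ENG"])
#   -> "wiki/api/v2/spaces?limit=50&keys=ENG"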
@@ -226,11 +319,9 @@ class OnyxConfluence:
        if self.scoped_token:
            # v2 endpoint doesn't always work with scoped tokens, use v1
            token = credentials["confluence_access_token"]
            probe_url = f"{self.base_url}/rest/api/space?limit=1"
            probe_url = f"{self.base_url}/{_CONFLUENCE_SPACES_API_V1}?limit=1"
            import requests

            logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")

            try:
                r = requests.get(
                    probe_url,
@@ -252,59 +343,23 @@ class OnyxConfluence:
                raise e
            return

        # probe connection with direct client, no retries
        if "confluence_refresh_token" in credentials:
            logger.info("Probing Confluence with OAuth Access Token.")
        # Initialize connection with probe timeout settings
        self._confluence = self._initialize_connection_helper(
            credentials, **merged_kwargs
        )

            oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
                credentials
            )
            url = (
                f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
            )
            confluence_client_with_minimal_retries = Confluence(
                url=url, oauth2=oauth2_dict, **merged_kwargs
            )
        else:
            logger.info("Probing Confluence with Personal Access Token.")
            url = self._url
            if self._is_cloud:
                logger.info("running with cloud client")
                confluence_client_with_minimal_retries = Confluence(
                    url=url,
                    username=credentials["confluence_username"],
                    password=credentials["confluence_access_token"],
                    **merged_kwargs,
                )
            else:
                confluence_client_with_minimal_retries = Confluence(
                    url=url,
                    token=credentials["confluence_access_token"],
                    **merged_kwargs,
                )
        # Retrieve first space to validate connection
        spaces_iter = self.retrieve_confluence_spaces(limit=1)
        first_space = next(spaces_iter, None)

        # This call sometimes hangs indefinitely, so we run it in a timeout
        spaces = run_with_timeout(
            timeout=10,
            func=confluence_client_with_minimal_retries.get_all_spaces,
            limit=1,
        if not first_space:
            raise RuntimeError(
                f"No spaces found at {self._url}! "
                "Check your credentials and wiki_base and make sure "
                "is_cloud is set correctly."
            )

        # uncomment the following for testing
        # the following is an attempt to retrieve the user's timezone
        # Unfornately, all data is returned in UTC regardless of the user's time zone
        # even tho CQL parses incoming times based on the user's time zone
        # space_key = spaces["results"][0]["key"]
        # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")

        if not spaces:
            raise RuntimeError(
                f"No spaces found at {url}! "
                "Check your credentials and wiki_base and make sure "
                "is_cloud is set correctly."
            )

        logger.info("Confluence probe succeeded.")
        logger.info("Confluence probe succeeded.")

    def _initialize_connection(
        self,
@@ -786,29 +841,6 @@ class OnyxConfluence:
                    type=user["accountType"],
                )
        else:
            # https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
            # ^ is only available on data center deployments
            # Example response:
            # [
            #     {
            #         'type': 'known',
            #         'username': 'admin',
            #         'userKey': '40281082950c5fe901950c61c55d0000',
            #         'profilePicture': {
            #             'path': '/images/icons/profilepics/default.svg',
            #             'width': 48,
            #             'height': 48,
            #             'isDefault': True
            #         },
            #         'displayName': 'Admin Test',
            #         '_links': {
            #             'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
            #         },
            #         '_expandable': {
            #             'status': ''
            #         }
            #     }
            # ]
            for user in self._paginate_url("rest/api/user/list", limit):
                yield ConfluenceUser(
                    user_id=user["userKey"],
@@ -10,6 +10,7 @@ from urllib.parse import urlparse

import requests
from dateutil.parser import parse
from dateutil.parser import ParserError

from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
@@ -31,22 +32,40 @@ def datetime_to_utc(dt: datetime) -> datetime:

def time_str_to_utc(datetime_str: str) -> datetime:
    # Remove all timezone abbreviations in parentheses
    datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()
    normalized = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()

    # Remove any remaining parentheses and their contents
    datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()
    normalized = re.sub(r"\(.*?\)", "", normalized).strip()

    try:
        dt = parse(datetime_str)
    except ValueError:
        # Fix common format issues (e.g. "0000" => "+0000")
        if "0000" in datetime_str:
            datetime_str = datetime_str.replace(" 0000", " +0000")
            dt = parse(datetime_str)
        else:
            raise
    candidates: list[str] = [normalized]

    return datetime_to_utc(dt)
    # Some sources (e.g. Gmail) may prefix the value with labels like "Date:"
    label_stripped = re.sub(
        r"^\s*[A-Za-z][A-Za-z\s_-]*:\s*", "", normalized, count=1
    ).strip()
    if label_stripped and label_stripped != normalized:
        candidates.append(label_stripped)

    # Fix common format issues (e.g. "0000" => "+0000")
    for candidate in list(candidates):
        if " 0000" in candidate:
            fixed = candidate.replace(" 0000", " +0000")
            if fixed not in candidates:
                candidates.append(fixed)

    last_exception: Exception | None = None
    for candidate in candidates:
        try:
            dt = parse(candidate)
            return datetime_to_utc(dt)
        except (ValueError, ParserError) as exc:
            last_exception = exc

    if last_exception is not None:
        raise last_exception

    # Fallback in case parsing failed without raising (should not happen)
    raise ValueError(f"Unable to parse datetime string: {datetime_str}")


def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
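A quick trace of the candidate expansion for a hypothetical Gmail-style input, following the new logic above:

# Input: "Date: Mon, 3 Jun 2024 10:00:00 0000 (UTC)"
#   strip "(UTC)"        -> "Date: Mon, 3 Jun 2024 10:00:00 0000"
#   strip "Date:" label  -> adds "Mon, 3 Jun 2024 10:00:00 0000"
#   " 0000" -> " +0000"  -> adds offset-fixed variants of both
# The first candidate dateutil can parse wins and is converted to UTC.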
@@ -58,6 +58,8 @@ EMAIL_FIELDS = [
    "to",
]

MAX_MESSAGE_BODY_BYTES = 10 * 1024 * 1024  # 10MB cap to keep large threads safe

add_retries = retry_builder(tries=50, max_delay=30)
@@ -120,16 +122,52 @@ def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertIn


def _get_message_body(payload: dict[str, Any]) -> str:
    parts = payload.get("parts", [])
    message_body = ""
    for part in parts:
    """
    Gmail threads can contain large inline parts (including attachments
    transmitted as base64). Only decode text/plain parts and skip anything
    that breaches the safety threshold to protect against OOMs.
    """

    message_body_chunks: list[str] = []
    stack = [payload]

    while stack:
        part = stack.pop()
        if not part:
            continue

        children = part.get("parts", [])
        stack.extend(reversed(children))

        mime_type = part.get("mimeType")
        body = part.get("body")
        if mime_type == "text/plain" and body:
            data = body.get("data", "")
        if mime_type != "text/plain":
            continue

        body = part.get("body", {})
        data = body.get("data", "")

        if not data:
            continue

        # base64 inflates storage by ~4/3; work with decoded size estimate
        approx_decoded_size = (len(data) * 3) // 4
        if approx_decoded_size > MAX_MESSAGE_BODY_BYTES:
            logger.warning(
                "Skipping oversized Gmail message part (%s bytes > %s limit)",
                approx_decoded_size,
                MAX_MESSAGE_BODY_BYTES,
            )
            continue

        try:
            text = urlsafe_b64decode(data).decode()
            message_body += text
    return message_body
        except (ValueError, UnicodeDecodeError) as error:
            logger.warning("Failed to decode Gmail message part: %s", error)
            continue

        message_body_chunks.append(text)

    return "".join(message_body_chunks)


def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
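The 3/4 estimate sizes the decoded payload without actually decoding it first. A minimal standalone check using only the standard library:

from base64 import urlsafe_b64encode

raw = b"x" * 1000
data = urlsafe_b64encode(raw).decode()
approx_decoded_size = (len(data) * 3) // 4  # slightly high due to padding
print(len(data), approx_decoded_size)  # 1336 1002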
@@ -50,6 +50,12 @@ logger = setup_logger()
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
# Fallback templates for generating web links when Drive omits webViewLink.
_FALLBACK_WEB_VIEW_LINK_TEMPLATES = {
    GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view",
    GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view",
    GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view",
}

MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64  # extra bytes past the limit to read
@@ -79,7 +85,25 @@ class PermissionSyncContext(BaseModel):


def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
    link = file[WEB_VIEW_LINK_KEY]
    link = file.get(WEB_VIEW_LINK_KEY)
    if not link:
        file_id = file.get("id")
        if not file_id:
            raise KeyError(
                f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields."
            )
        mime_type = file.get("mimeType", "")
        template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type)
        if template is None:
            link = f"https://drive.google.com/file/d/{file_id}/view"
        else:
            link = template.format(file_id)
        logger.debug(
            "Missing webViewLink for Google Drive file with id %s. "
            "Falling back to constructed link %s",
            file_id,
            link,
        )
    parsed_url = urlparse(link)
    parsed_url = parsed_url._replace(query="")  # remove query parameters
    spl_path = parsed_url.path.split("/")
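Under the fallback above, a Google Doc that lacks webViewLink resolves via the DOC template, while unknown types get the generic Drive URL. A sketch with hypothetical field values:

file = {"id": "abc123", "mimeType": "application/vnd.google-apps.document"}
# template hit    -> "https://docs.google.com/document/d/abc123/view"
# unknown mime    -> "https://drive.google.com/file/d/abc123/view"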
@@ -1,4 +1,5 @@
import re
import socket
import time
from collections.abc import Callable
from collections.abc import Iterator
@@ -152,6 +153,12 @@ def _execute_single_retrieval(
        else:
            logger.exception("Error executing request:")
            raise e
    except (TimeoutError, socket.timeout) as error:
        logger.warning(
            "Timed out executing Google API request; retrying with backoff. Details: %s",
            error,
        )
        results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()

    return results
@@ -191,7 +191,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):

    @abc.abstractmethod
    def is_dynamic(self) -> bool:
        """If dynamic, the credentials may change during usage ... maening the client
        """If dynamic, the credentials may change during usage ... meaning the client
        needs to use the locking features of the credentials provider to operate
        correctly.
@@ -644,6 +644,7 @@ class JiraConnector(
                    jql=self.jql_query,
                    start=0,
                    max_results=1,
                    all_issue_ids=[],
                )
            ),
            None,
@@ -57,6 +57,8 @@ from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_a
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
@@ -770,7 +772,7 @@ class SharepointConnector(
        try:
            site = self.graph_client.sites.get_by_url(site_descriptor.url)
            drives = site.drives.get().execute_query()
            logger.debug(f"Found drives: {[drive.name for drive in drives]}")
            logger.info(f"Found drives: {[drive.name for drive in drives]}")

            drives = [
                drive
@@ -782,19 +784,23 @@ class SharepointConnector(
        if drive is None:
            logger.warning(f"Drive '{drive_name}' not found")
            return []

        logger.info(f"Found drive: {drive.name}")
        try:
            root_folder = drive.root
            if site_descriptor.folder_path:
                for folder_part in site_descriptor.folder_path.split("/"):
                    root_folder = root_folder.get_by_path(folder_part)

            logger.info(f"Found root folder: {root_folder.name}")

            # TODO: consider ways to avoid materializing the entire list of files in memory
            query = root_folder.get_files(
                recursive=True,
                page_size=1000,
            )
            driveitems = query.execute_query()
            logger.debug(f"Found {len(driveitems)} items in drive '{drive_name}'")
            logger.info(f"Found {len(driveitems)} items in drive '{drive_name}'")

            # Filter items based on folder path if specified
            if site_descriptor.folder_path:
@@ -833,7 +839,7 @@ class SharepointConnector(
                    <= item.last_modified_datetime.replace(tzinfo=timezone.utc)
                    <= end
                ]
                logger.debug(
                logger.info(
                    f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
                )
@@ -1420,6 +1426,9 @@ class SharepointConnector(
                return checkpoint

            try:
                logger.info(
                    f"Fetching drive items for drive name: {current_drive_name}"
                )
                driveitems = self._get_drive_items_for_drive_name(
                    site_descriptor, current_drive_name, start_dt, end_dt
                )
@@ -1453,6 +1462,12 @@ class SharepointConnector(
                )
                for driveitem in driveitems:
                    driveitem_extension = get_file_ext(driveitem.name)
                    if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
                        logger.warning(
                            f"Skipping {driveitem.web_url} as it is not a supported file type"
                        )
                        continue

                    # Only yield empty documents if they are PDFs or images
                    should_yield_if_empty = (
                        driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
@@ -1476,6 +1491,10 @@ class SharepointConnector(
                                TextSection(link=driveitem.web_url, text="")
                            ]
                            yield doc
                        else:
                            logger.warning(
                                f"Skipping {driveitem.web_url} as it is empty and not a PDF or image"
                            )
                    except Exception as e:
                        logger.warning(
                            f"Failed to process driveitem {driveitem.web_url}: {e}"
@@ -378,6 +378,16 @@ def _update_request_url(request: RequestOptions, next_url: str) -> None:
    request.url = next_url


def _add_prefer_header(request: RequestOptions) -> None:
    """Add Prefer header to work around Microsoft Graph API ampersand bug.
    See: https://developer.microsoft.com/en-us/graph/known-issues/?search=18185
    """
    if not hasattr(request, "headers") or request.headers is None:
        request.headers = {}
    # Add header to handle properly encoded ampersands in filters
    request.headers["Prefer"] = "legacySearch=false"


def _collect_all_teams(
    graph_client: GraphClient,
    requested: list[str] | None = None,
@@ -385,26 +395,52 @@ def _collect_all_teams(
    teams: list[Team] = []
    next_url: str | None = None

    # Build OData filter for requested teams
    # Only escape single quotes for OData syntax - the library handles URL encoding
    filter = None
    if requested:
        filter = " or ".join(f"displayName eq '{team_name}'" for team_name in requested)
    use_filter = bool(requested)
    if use_filter and requested:
        filter_parts = []
        for name in requested:
            # Escape single quotes for OData syntax (replace ' with '')
            # The office365 library will handle URL encoding of the entire filter
            escaped_name = name.replace("'", "''")
            filter_parts.append(f"displayName eq '{escaped_name}'")
        filter = " or ".join(filter_parts)

    while True:
        if filter:
            query = graph_client.teams.get().filter(filter)
        else:
            query = graph_client.teams.get_all(
                # explicitly needed because of incorrect type definitions provided by the `office365` library
                page_loaded=lambda _: None
            )
        try:
            if filter:
                query = graph_client.teams.get().filter(filter)
                # Add header to work around Microsoft Graph API ampersand bug
                query.before_execute(lambda req: _add_prefer_header(request=req))
            else:
                query = graph_client.teams.get_all(
                    # explicitly needed because of incorrect type definitions provided by the `office365` library
                    page_loaded=lambda _: None
                )

            if next_url:
                url = next_url
                query.before_execute(
                    lambda req: _update_request_url(request=req, next_url=url)
                )
        if next_url:
            url = next_url
            query.before_execute(
                lambda req: _update_request_url(request=req, next_url=url)
            )

            team_collection = query.execute_query()
        except (ClientRequestException, ValueError) as e:
            # If OData filter fails, fallback to client-side filtering
            if use_filter:
                logger.warning(
                    f"OData filter failed with {type(e).__name__}: {e}. "
                    f"Falling back to client-side filtering."
                )
                use_filter = False
                filter = None
                teams = []
                next_url = None
                continue
            raise

        team_collection = query.execute_query()
        filtered_teams = (
            team
            for team in team_collection
@@ -535,7 +535,8 @@ class WebConnector(LoadConnector):
                id=initial_url,
                sections=[TextSection(link=initial_url, text=page_text)],
                source=DocumentSource.WEB,
-               semantic_identifier=initial_url.split("/")[-1],
+               semantic_identifier=initial_url.rstrip("/").split("/")[-1]
+               or initial_url,
                metadata=metadata,
                doc_updated_at=(
                    _get_datetime_from_last_modified_header(last_modified)
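The change matters for URLs with a trailing slash, where the old expression produced an empty identifier. A quick illustration (URLs made up):

    # Illustrative only - why rstrip("/") plus the "or" fallback is needed
    url = "https://example.com/docs/"
    assert url.split("/")[-1] == ""                         # old behavior: empty identifier
    assert url.rstrip("/").split("/")[-1] == "docs"         # new behavior
    assert ("/".rstrip("/").split("/")[-1] or "/") == "/"   # fallback kicks in for bare "/"
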
@@ -959,6 +959,13 @@ def translate_db_message_to_chat_message_detail(
    chat_message: ChatMessage,
    remove_doc_content: bool = False,
) -> ChatMessageDetail:
    # Get current feedback if any
    current_feedback = None
    if chat_message.chat_message_feedbacks:
        latest_feedback = chat_message.chat_message_feedbacks[-1]
        if latest_feedback.is_positive is not None:
            current_feedback = "like" if latest_feedback.is_positive else "dislike"

    chat_msg_detail = ChatMessageDetail(
        chat_session_id=chat_message.chat_session_id,
        message_id=chat_message.id,
@@ -986,6 +993,7 @@ def translate_db_message_to_chat_message_detail(
        alternate_assistant_id=chat_message.alternate_assistant_id,
        overridden_model=chat_message.overridden_model,
        error=chat_message.error,
        current_feedback=current_feedback,
    )

    return chat_msg_detail
@@ -1188,6 +1196,7 @@ def update_db_session_with_messages(
    final_documents: list[SearchDoc] | None = None,
    update_parent_message: bool = True,
    research_answer_purpose: ResearchAnswerPurpose | None = None,
    files: list[FileDescriptor] | None = None,
    commit: bool = False,
) -> ChatMessage:
    chat_message = (
@@ -1230,6 +1239,9 @@ def update_db_session_with_messages(
    if research_answer_purpose:
        chat_message.research_answer_purpose = research_answer_purpose

    if files is not None:
        chat_message.files = files

    if update_parent_message:
        parent_chat_message = (
            db_session.query(ChatMessage)
@@ -15,6 +15,7 @@ from onyx.connectors.models import InputType
from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import FederatedConnector
from onyx.db.models import IndexAttempt
from onyx.kg.models import KGConnectorData
from onyx.server.documents.models import ConnectorBase
@@ -25,6 +26,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


def check_federated_connectors_exist(db_session: Session) -> bool:
    stmt = select(exists(FederatedConnector))
    result = db_session.execute(stmt)
    return result.scalar() or False


def check_connectors_exist(db_session: Session) -> bool:
    # Connector 0 is created on server startup as a default for ingestion
    # it will always exist and we don't need to count it for this
@@ -2,3 +2,12 @@ SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL"

CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:"


# Sentinel value to distinguish between "not provided" and "explicitly set to None"
class UnsetType:
    def __repr__(self) -> str:
        return "<UNSET>"


UNSET = UnsetType()

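A short sketch of why the sentinel beats plain None for optional updates (the function name here is illustrative; update_tool later in this diff uses the same pattern):

    # Illustrative only - None means "clear", UNSET means "don't touch"
    def set_oauth_link(oauth_config_id: int | None | UnsetType = UNSET) -> None:
        if not isinstance(oauth_config_id, UnsetType):
            print(f"setting oauth_config_id to {oauth_config_id}")  # None clears the link
        else:
            print("argument omitted - leaving the field untouched")

    set_oauth_link()                      # field untouched
    set_oauth_link(oauth_config_id=None)  # explicitly cleared
    set_oauth_link(oauth_config_id=7)     # explicitly set
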
@@ -173,3 +173,9 @@ class UserFileStatus(str, PyEnum):
    FAILED = "FAILED"
    CANCELED = "CANCELED"
    DELETING = "DELETING"


class ThemePreference(str, PyEnum):
    LIGHT = "light"
    DARK = "dark"
    SYSTEM = "system"

@@ -262,3 +262,24 @@ def create_chat_message_feedback(

    db_session.add(message_feedback)
    db_session.commit()


def remove_chat_message_feedback(
    chat_message_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> None:
    """Remove all feedback for a chat message."""
    chat_message = get_chat_message(
        chat_message_id=chat_message_id, user_id=user_id, db_session=db_session
    )

    if chat_message.message_type != MessageType.ASSISTANT:
        raise ValueError("Can only remove feedback from LLM Outputs")

    # Delete all feedback for this message
    db_session.query(ChatMessageFeedback).filter(
        ChatMessageFeedback.chat_message_id == chat_message_id
    ).delete()

    db_session.commit()

@@ -79,11 +79,21 @@ def upsert_llm_provider(
        existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
        db_session.add(existing_llm_provider)

    # Filter out empty strings and None values from custom_config to allow
    # providers like Bedrock to fall back to IAM roles when credentials are not provided
    custom_config = llm_provider_upsert_request.custom_config
    if custom_config:
        custom_config = {
            k: v for k, v in custom_config.items() if v is not None and v.strip() != ""
        }
        # Set to None if the dict is empty after filtering
        custom_config = custom_config if custom_config else None

    existing_llm_provider.provider = llm_provider_upsert_request.provider
    existing_llm_provider.api_key = llm_provider_upsert_request.api_key
    existing_llm_provider.api_base = llm_provider_upsert_request.api_base
    existing_llm_provider.api_version = llm_provider_upsert_request.api_version
-   existing_llm_provider.custom_config = llm_provider_upsert_request.custom_config
+   existing_llm_provider.custom_config = custom_config
    existing_llm_provider.default_model_name = (
        llm_provider_upsert_request.default_model_name
    )
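A short illustration of what the filter does for a Bedrock-style config where the admin left the credential fields blank (key names and values are made up):

    # Illustrative input - blank credentials, one real value
    custom_config = {
        "AWS_ACCESS_KEY_ID": "",
        "AWS_SECRET_ACCESS_KEY": "   ",
        "AWS_REGION_NAME": "us-east-1",
    }
    filtered = {k: v for k, v in custom_config.items() if v is not None and v.strip() != ""}
    assert filtered == {"AWS_REGION_NAME": "us-east-1"}
    # With the blank keys dropped, the provider can fall back to IAM-role credentials.
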
@@ -65,6 +65,7 @@ from onyx.db.enums import (
    UserFileStatus,
    MCPAuthenticationPerformer,
    MCPTransport,
    ThemePreference,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -183,6 +184,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    )
    auto_scroll: Mapped[bool | None] = mapped_column(Boolean, default=None)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    theme_preference: Mapped[ThemePreference | None] = mapped_column(
        Enum(ThemePreference, native_enum=False),
        nullable=True,
        default=None,
    )
    # personalization fields are exposed via the chat user settings "Personalization" tab
    personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
    personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -248,6 +254,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
        cascade="all, delete-orphan",
        lazy="selectin",
    )
    oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
        "OAuthUserToken",
        back_populates="user",
        cascade="all, delete-orphan",
    )

    @validates("email")
    def validate_email(self, key: str, value: str) -> str:
@@ -2509,9 +2520,16 @@ class Tool(Base):
    mcp_server_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True
    )
    # OAuth configuration for this tool (null for tools without OAuth)
    oauth_config_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("oauth_config.id", ondelete="SET NULL"), nullable=True
    )
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
    oauth_config: Mapped["OAuthConfig | None"] = relationship(
        "OAuthConfig", back_populates="tools"
    )
    # Relationship to Persona through the association table
    personas: Mapped[list["Persona"]] = relationship(
        "Persona",
@@ -2524,6 +2542,92 @@ class Tool(Base):
    )


class OAuthConfig(Base):
    """OAuth provider configuration that can be shared across multiple tools"""

    __tablename__ = "oauth_config"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)

    # OAuth provider endpoints
    authorization_url: Mapped[str] = mapped_column(Text, nullable=False)
    token_url: Mapped[str] = mapped_column(Text, nullable=False)

    # Client credentials (encrypted)
    client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
    client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)

    # Optional configurations
    scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
    additional_params: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Metadata
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="oauth_config")
    user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
        "OAuthUserToken", back_populates="oauth_config", cascade="all, delete-orphan"
    )


class OAuthUserToken(Base):
    """Per-user OAuth tokens for a specific OAuth configuration"""

    __tablename__ = "oauth_user_token"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    oauth_config_id: Mapped[int] = mapped_column(
        ForeignKey("oauth_config.id", ondelete="CASCADE"), nullable=False
    )
    user_id: Mapped[UUID] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )

    # Token data (encrypted)
    # Structure: {
    #     "access_token": "...",
    #     "refresh_token": "...",  # Optional
    #     "token_type": "Bearer",
    #     "expires_at": 1234567890,  # Unix timestamp, optional
    #     "scope": "repo user"  # Optional
    # }
    token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)

    # Metadata
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    oauth_config: Mapped["OAuthConfig"] = relationship(
        "OAuthConfig", back_populates="user_tokens"
    )
    user: Mapped["User"] = relationship("User")

    # Unique constraint: One token per user per OAuth config
    __table_args__ = (
        UniqueConstraint("oauth_config_id", "user_id", name="uq_oauth_user_token"),
    )

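For reference, a token_data payload matching the structure documented above might look like this (every value here is fabricated):

    # Fabricated example of a stored token_data blob
    example_token_data = {
        "access_token": "gho_exampleAccessToken123",
        "refresh_token": "ghr_exampleRefreshToken456",  # optional
        "token_type": "Bearer",
        "expires_at": 1767225600,  # optional Unix timestamp
        "scope": "repo user",  # optional
    }
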
class StarterMessage(BaseModel):
    """Starter message for a persona."""

backend/onyx/db/oauth_config.py (new file, 195 lines)
@@ -0,0 +1,195 @@
from typing import Any
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.models import OAuthConfig
from onyx.db.models import OAuthUserToken
from onyx.db.models import Tool
from onyx.utils.logger import setup_logger


logger = setup_logger()


# OAuth Config CRUD operations


def create_oauth_config(
    name: str,
    authorization_url: str,
    token_url: str,
    client_id: str,
    client_secret: str,
    scopes: list[str] | None,
    additional_params: dict[str, str] | None,
    db_session: Session,
) -> OAuthConfig:
    """Create a new OAuth configuration"""
    oauth_config = OAuthConfig(
        name=name,
        authorization_url=authorization_url,
        token_url=token_url,
        client_id=client_id,
        client_secret=client_secret,
        scopes=scopes,
        additional_params=additional_params,
    )
    db_session.add(oauth_config)
    db_session.commit()
    return oauth_config


def get_oauth_config(oauth_config_id: int, db_session: Session) -> OAuthConfig | None:
    """Get OAuth configuration by ID"""
    return db_session.scalar(
        select(OAuthConfig).where(OAuthConfig.id == oauth_config_id)
    )


def get_oauth_configs(db_session: Session) -> list[OAuthConfig]:
    """Get all OAuth configurations"""
    return list(db_session.scalars(select(OAuthConfig)).all())


def update_oauth_config(
    oauth_config_id: int,
    db_session: Session,
    name: str | None = None,
    authorization_url: str | None = None,
    token_url: str | None = None,
    client_id: str | None = None,
    client_secret: str | None = None,
    scopes: list[str] | None = None,
    additional_params: dict[str, Any] | None = None,
    clear_client_id: bool = False,
    clear_client_secret: bool = False,
) -> OAuthConfig:
    """
    Update OAuth configuration.

    NOTE: If client_id or client_secret are None, existing values are preserved.
    To clear these values, set clear_client_id or clear_client_secret to True.
    This allows partial updates without re-entering secrets.
    """
    oauth_config = db_session.scalar(
        select(OAuthConfig).where(OAuthConfig.id == oauth_config_id)
    )
    if oauth_config is None:
        raise ValueError(f"OAuth config with id {oauth_config_id} does not exist")

    # Update only provided fields
    if name is not None:
        oauth_config.name = name
    if authorization_url is not None:
        oauth_config.authorization_url = authorization_url
    if token_url is not None:
        oauth_config.token_url = token_url
    if clear_client_id:
        oauth_config.client_id = ""
    elif client_id is not None:
        oauth_config.client_id = client_id
    if clear_client_secret:
        oauth_config.client_secret = ""
    elif client_secret is not None:
        oauth_config.client_secret = client_secret
    if scopes is not None:
        oauth_config.scopes = scopes
    if additional_params is not None:
        oauth_config.additional_params = additional_params

    db_session.commit()
    return oauth_config

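A sketch of the partial-update semantics this enables (an open db_session and a config with id 1 are assumed):

    # Rename only - client_secret is None, so the stored secret is preserved
    update_oauth_config(oauth_config_id=1, db_session=db_session, name="GitHub OAuth")

    # Explicitly wipe the secret - requires the dedicated flag
    update_oauth_config(oauth_config_id=1, db_session=db_session, clear_client_secret=True)
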
def delete_oauth_config(oauth_config_id: int, db_session: Session) -> None:
    """
    Delete OAuth configuration.

    Sets oauth_config_id to NULL for associated tools due to SET NULL foreign key.
    Cascades delete to user tokens.
    """
    oauth_config = db_session.scalar(
        select(OAuthConfig).where(OAuthConfig.id == oauth_config_id)
    )
    if oauth_config is None:
        raise ValueError(f"OAuth config with id {oauth_config_id} does not exist")

    db_session.delete(oauth_config)
    db_session.commit()


# User Token operations


def get_user_oauth_token(
    oauth_config_id: int, user_id: UUID, db_session: Session
) -> OAuthUserToken | None:
    """Get user's OAuth token for a specific configuration"""
    return db_session.scalar(
        select(OAuthUserToken).where(
            OAuthUserToken.oauth_config_id == oauth_config_id,
            OAuthUserToken.user_id == user_id,
        )
    )


def get_all_user_oauth_tokens(
    user_id: UUID, db_session: Session
) -> list[OAuthUserToken]:
    """
    Get all user OAuth tokens.
    """
    stmt = select(OAuthUserToken).where(OAuthUserToken.user_id == user_id)

    return list(db_session.scalars(stmt).all())


def upsert_user_oauth_token(
    oauth_config_id: int, user_id: UUID, token_data: dict, db_session: Session
) -> OAuthUserToken:
    """Insert or update user's OAuth token for a specific configuration"""
    existing_token = get_user_oauth_token(oauth_config_id, user_id, db_session)

    if existing_token:
        # Update existing token
        existing_token.token_data = token_data
        db_session.commit()
        return existing_token
    else:
        # Create new token
        new_token = OAuthUserToken(
            oauth_config_id=oauth_config_id,
            user_id=user_id,
            token_data=token_data,
        )
        db_session.add(new_token)
        db_session.commit()
        return new_token


def delete_user_oauth_token(
    oauth_config_id: int, user_id: UUID, db_session: Session
) -> None:
    """Delete user's OAuth token for a specific configuration"""
    user_token = get_user_oauth_token(oauth_config_id, user_id, db_session)
    if user_token is None:
        raise ValueError(
            f"OAuth token for user {user_id} and config {oauth_config_id} does not exist"
        )

    db_session.delete(user_token)
    db_session.commit()


# Helper operations


def get_tools_by_oauth_config(oauth_config_id: int, db_session: Session) -> list[Tool]:
    """Get all tools that use a specific OAuth configuration"""
    return list(
        db_session.scalars(
            select(Tool).where(Tool.oauth_config_id == oauth_config_id)
        ).all()
    )
@@ -7,6 +7,8 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.models import Tool
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
@@ -62,6 +64,7 @@ def create_tool__no_commit(
    passthrough_auth: bool,
    *,
    mcp_server_id: int | None = None,
    oauth_config_id: int | None = None,
    enabled: bool = True,
) -> Tool:
    new_tool = Tool(
@@ -75,6 +78,7 @@ def create_tool__no_commit(
        user_id=user_id,
        passthrough_auth=passthrough_auth,
        mcp_server_id=mcp_server_id,
        oauth_config_id=oauth_config_id,
        enabled=enabled,
    )
    db_session.add(new_tool)
@@ -91,6 +95,7 @@ def update_tool(
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool | None,
    oauth_config_id: int | None | UnsetType = UNSET,
) -> Tool:
    tool = get_tool_by_id(tool_id, db_session)
    if tool is None:
@@ -110,6 +115,8 @@ def update_tool(
    ]
    if passthrough_auth is not None:
        tool.passthrough_auth = passthrough_auth
    if not isinstance(oauth_config_id, UnsetType):
        tool.oauth_config_id = oauth_config_id
    db_session.commit()

    return tool

@@ -9,6 +9,7 @@ from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.db.enums import ThemePreference
from onyx.db.models import AccessToken
from onyx.db.models import Assistant__UserSpecificConfig
from onyx.db.models import Memory
@@ -124,6 +125,20 @@ def update_user_default_model(
    db_session.commit()


def update_user_theme_preference(
    user_id: UUID,
    theme_preference: ThemePreference,
    db_session: Session,
) -> None:
    """Update user's theme preference setting."""
    db_session.execute(
        update(User)
        .where(User.id == user_id)  # type: ignore
        .values(theme_preference=theme_preference)
    )
    db_session.commit()


def update_user_personalization(
    user_id: UUID,
    *,
@@ -363,4 +363,10 @@ schema {{ schema_name }} {
            expression: bm25(content) + (5 * bm25(title))
        }
    }

    rank-profile random_ inherits default {
        first-phase {
            expression: random
        }
    }
}
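Presumably the profile exists so callers can pull documents back in random order; a hedged sketch of selecting it at query time through Vespa's HTTP query API (the endpoint URL and YQL source are assumptions; ranking.profile is the standard parameter name):

    # Hypothetical query sketch - endpoint and document source are assumed
    import requests

    response = requests.post(
        "http://localhost:8080/search/",  # assumed Vespa query endpoint
        json={
            "yql": "select * from sources * where true",
            "ranking.profile": "random_",  # the rank profile added above
            "hits": 10,
        },
    )
    print(response.json())
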
@@ -23,11 +23,7 @@ class EvalConfiguration(BaseModel):


class EvalConfigurationOptions(BaseModel):
-   builtin_tool_types: list[str] = list(
-       tool_name
-       for tool_name in BUILT_IN_TOOL_MAP.keys()
-       if tool_name != "OktaProfileTool"
-   )
+   builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
    persona_override_config: PersonaOverrideConfig | None = None
    llm: LLMOverride = LLMOverride(
        model_provider="Default",
@@ -3,4 +3,4 @@ Feature flag keys used throughout the application.
Centralizes feature flag key definitions to avoid magic strings.
"""

-SIMPLE_AGENT_FRAMEWORK = "simple-agent-framework"
+DISABLE_SIMPLE_AGENT_FRAMEWORK = "disable-simple-agent-framework"
@@ -89,6 +89,7 @@ _MARKITDOWN_CONVERTER: Optional["MarkItDown"] = None
KNOWN_OPENPYXL_BUGS = [
    "Value must be either numerical or a string containing a wildcard",
    "File contains no valid workbook part",
    "Unable to read workbook: could not read stylesheet from None",
]

@@ -430,9 +430,20 @@ class S3BackedFileStore(FileStore):

        # Delete from external storage
        s3_client = self._get_s3_client()
-       s3_client.delete_object(
-           Bucket=file_record.bucket_name, Key=file_record.object_key
-       )
+       try:
+           s3_client.delete_object(
+               Bucket=file_record.bucket_name, Key=file_record.object_key
+           )
+       except ClientError as e:
+           # If the object doesn't exist in file store, treat it as success
+           # since the end goal (object not existing) is achieved
+           if e.response.get("Error", {}).get("Code") == "NoSuchKey":
+               logger.warning(
+                   f"delete_file: File {file_id} not found in file store (key: {file_record.object_key}), "
+                   "cleaning up database record."
+               )
+           else:
+               raise

        # Delete metadata from database
        delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
@@ -37,6 +37,7 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
from onyx.llm.utils import model_is_reasoning_model
@@ -306,9 +307,15 @@ class DefaultMultiLLM(LLM):
                model_kwargs[k] = v
                continue

-           # for all values, set them as env variables
-           os.environ[k] = v
+           # If there are any empty or null values,
+           # they MUST NOT be set in the env
+           if v is not None and v.strip():
+               os.environ[k] = v
+           else:
+               os.environ.pop(k, None)
        # This is needed for Ollama to do proper function calling
        if model_provider == OLLAMA_PROVIDER_NAME and api_base is not None:
            os.environ["OLLAMA_API_BASE"] = api_base
        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
@@ -1,5 +1,12 @@
import os
from typing import Any

from agents import ModelSettings
from agents.models.interface import Model

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
@@ -8,6 +15,8 @@ from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.chat_llm import VERTEX_LOCATION_KWARG
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
@@ -15,6 +24,7 @@ from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPENROUTER_PROVIDER_NAME
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_is_reasoning_model
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
@@ -27,7 +37,6 @@ logger = setup_logger()
def _build_provider_extra_headers(
    provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
    # Ollama Cloud: allow passing Bearer token via custom config for cloud instances
    if provider == OLLAMA_PROVIDER_NAME and custom_config:
        raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
        api_key = raw_api_key.strip() if raw_api_key else None
@@ -108,6 +117,56 @@ def get_llms_for_persona(
    return _create_llm(model), _create_llm(fast_model)


def get_llm_model_and_settings_for_persona(
    persona: Persona,
    llm_override: LLMOverride | None = None,
    additional_headers: dict[str, str] | None = None,
    timeout: int | None = None,
) -> tuple[Model, ModelSettings]:
    """Get LitellmModel and settings for a persona.

    Returns a tuple of:
    - LitellmModel instance
    - ModelSettings configured with the persona's parameters
    """
    provider_name_override = llm_override.model_provider if llm_override else None
    model_version_override = llm_override.model_version if llm_override else None
    temperature_override = llm_override.temperature if llm_override else None

    provider_name = provider_name_override or persona.llm_model_provider_override
    model_name = None
    if not provider_name:
        with get_session_with_current_tenant() as db_session:
            llm_provider = fetch_default_provider(db_session)

        if not llm_provider:
            raise ValueError("No default LLM provider found")

        model_name = llm_provider.default_model_name
    else:
        with get_session_with_current_tenant() as db_session:
            llm_provider = fetch_llm_provider_view(db_session, provider_name)

    model = model_version_override or persona.llm_model_version_override or model_name
    if not model:
        raise ValueError("No model name found")
    if not llm_provider:
        raise ValueError("No LLM provider found")

    return get_llm_model_and_settings(
        provider=llm_provider.provider,
        model=model,
        deployment_name=llm_provider.deployment_name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        temperature=temperature_override,
        timeout=timeout,
        additional_headers=additional_headers,
    )


def get_default_llm_with_vision(
    timeout: int | None = None,
    temperature: float | None = None,
@@ -311,3 +370,84 @@ def get_llm(
        long_term_logger=long_term_logger,
        max_input_tokens=max_input_tokens,
    )


def get_llm_model_and_settings(
    provider: str,
    model: str,
    deployment_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    api_version: str | None = None,
    custom_config: dict[str, str] | None = None,
    temperature: float | None = None,
    timeout: int | None = None,
    additional_headers: dict[str, str] | None = None,
    model_kwargs: dict[str, Any] | None = None,
) -> tuple[Model, ModelSettings]:
    from onyx.llm.litellm_singleton import LitellmModel

    if temperature is None:
        temperature = GEN_AI_TEMPERATURE

    # Configure timeout following the same pattern as DefaultMultiLLM
    if timeout is None:
        if model_is_reasoning_model(model, provider):
            timeout = QA_TIMEOUT * 10  # Reasoning models are slow
        else:
            timeout = QA_TIMEOUT

    extra_headers = build_llm_extra_headers(additional_headers)

    # NOTE: this is needed since Ollama API key is optional
    # User may access Ollama cloud via locally hosted instance (logged in)
    # or just via the cloud API (not logged in, using API key)
    provider_extra_headers = _build_provider_extra_headers(provider, custom_config)
    if provider_extra_headers:
        extra_headers.update(provider_extra_headers)

    # NOTE: have to set these as environment variables for Litellm since
    # not all of them can be passed in, but they are always supported when set
    # as env variables. We'll also try passing them in, since litellm just
    # ignores additional kwargs (and some kwargs MUST be passed in rather than
    # set as env variables)
    model_kwargs = model_kwargs or {}
    if custom_config:
        for k, v in custom_config.items():
            os.environ[k] = v
    if custom_config and provider == "vertex_ai":
        for k, v in custom_config.items():
            if k == VERTEX_CREDENTIALS_FILE_KWARG:
                model_kwargs[k] = v
                continue
            elif k == VERTEX_LOCATION_KWARG:
                model_kwargs[k] = v
                continue
    # This is needed for Ollama to do proper function calling
    if provider == OLLAMA_PROVIDER_NAME and api_base is not None:
        os.environ["OLLAMA_API_BASE"] = api_base
    if api_version:
        model_kwargs["api_version"] = api_version
    # Add timeout to model_kwargs so it gets passed to litellm
    model_kwargs["timeout"] = timeout
    # Build the full model name in provider/model format
    model_name = f"{provider}/{deployment_name or model}"

    # Create LitellmModel instance
    litellm_model = LitellmModel(
        model=model_name,
        # NOTE: have to pass in None instead of empty string for these
        # otherwise litellm can have some issues with bedrock
        base_url=api_base or None,
        api_key=api_key or None,
    )

    # Create ModelSettings with the provided configuration
    model_settings = ModelSettings(
        temperature=temperature,
        include_usage=True,
        extra_headers=extra_headers if extra_headers else None,
        extra_args=model_kwargs,
    )

    return litellm_model, model_settings
@@ -7,14 +7,13 @@ All other modules should import litellm from here instead of directly.
import litellm
from agents.extensions.models.litellm_model import LitellmModel

+from .config import initialize_litellm
+from .monkey_patches import apply_monkey_patches

-# Import litellm
+# Initialize litellm configuration immediately on import
+# This ensures the singleton pattern - configuration happens only once
+initialize_litellm()
+apply_monkey_patches()

-# Configure litellm settings immediately on import
-# If a user configures a different model and it doesn't support all the same
-# parameters like frequency and presence, just ignore them
-litellm.drop_params = True
-litellm.telemetry = False

-# Export the configured litellm module
+# Export the configured litellm module and model
__all__ = ["litellm", "LitellmModel"]
backend/onyx/llm/litellm_singleton/config.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import litellm


def configure_litellm_settings() -> None:
    # If a user configures a different model and it doesn't support all the same
    # parameters like frequency and presence, just ignore them
    litellm.drop_params = True
    litellm.telemetry = False
    litellm.modify_params = True


def register_ollama_models() -> None:
    litellm.register_model(
        model_cost={
            # GPT-OSS models
            "ollama_chat/gpt-oss:120b-cloud": {"supports_function_calling": True},
            "ollama_chat/gpt-oss:120b": {"supports_function_calling": True},
            "ollama_chat/gpt-oss:20b-cloud": {"supports_function_calling": True},
            "ollama_chat/gpt-oss:20b": {"supports_function_calling": True},
            # DeepSeek models
            "ollama_chat/deepseek-r1:latest": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:1.5b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:7b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:8b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:14b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:32b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:70b": {"supports_function_calling": True},
            "ollama_chat/deepseek-r1:671b": {"supports_function_calling": True},
            "ollama_chat/deepseek-v3.1:latest": {"supports_function_calling": True},
            "ollama_chat/deepseek-v3.1:671b": {"supports_function_calling": True},
            "ollama_chat/deepseek-v3.1:671b-cloud": {"supports_function_calling": True},
            # Gemma3 models
            "ollama_chat/gemma3:latest": {"supports_function_calling": True},
            "ollama_chat/gemma3:270m": {"supports_function_calling": True},
            "ollama_chat/gemma3:1b": {"supports_function_calling": True},
            "ollama_chat/gemma3:4b": {"supports_function_calling": True},
            "ollama_chat/gemma3:12b": {"supports_function_calling": True},
            "ollama_chat/gemma3:27b": {"supports_function_calling": True},
            # Qwen models
            "ollama_chat/qwen3-coder:latest": {"supports_function_calling": True},
            "ollama_chat/qwen3-coder:30b": {"supports_function_calling": True},
            "ollama_chat/qwen3-coder:480b": {"supports_function_calling": True},
            "ollama_chat/qwen3-coder:480b-cloud": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:latest": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:2b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:4b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:8b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:30b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:32b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:235b": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:235b-cloud": {"supports_function_calling": True},
            "ollama_chat/qwen3-vl:235b-instruct-cloud": {
                "supports_function_calling": True
            },
            # Kimi
            "ollama_chat/kimi-k2:1t": {"supports_function_calling": True},
            "ollama_chat/kimi-k2:1t-cloud": {"supports_function_calling": True},
            # GLM
            "ollama_chat/glm-4.6:cloud": {"supports_function_calling": True},
            "ollama_chat/glm-4.6": {"supports_function_calling": True},
        }
    )


def initialize_litellm() -> None:
    configure_litellm_settings()
    register_ollama_models()
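After registration, litellm should report tool-call support for these models. A small sanity-check sketch (this assumes litellm.supports_function_calling, which exists in recent litellm releases):

    # Sanity-check sketch - importing the singleton runs initialize_litellm()
    from onyx.llm.litellm_singleton import litellm

    print(litellm.supports_function_calling(model="ollama_chat/gpt-oss:20b"))  # expected: True
    print(litellm.supports_function_calling(model="ollama_chat/gemma3:4b"))  # expected: True
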
backend/onyx/llm/litellm_singleton/monkey_patches.py (new file, 281 lines)
@@ -0,0 +1,281 @@
import json
import time
import uuid
from typing import Any
from typing import cast
from typing import List
from typing import Optional
from typing import Tuple

from litellm import AllMessageValues
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    extract_images_from_message,
)
from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
from litellm.llms.ollama.chat.transformation import OllamaChatConfig
from litellm.llms.ollama.common_utils import OllamaError
from litellm.types.llms.ollama import OllamaChatCompletionMessage
from litellm.types.llms.ollama import OllamaToolCall
from litellm.types.llms.ollama import OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import ChatCompletionUsageBlock
from litellm.types.utils import ModelResponseStream
from pydantic import BaseModel


def _patch_ollama_transform_request() -> None:
    """
    Patches OllamaChatConfig.transform_request to handle reasoning content
    and tool calls properly for Ollama chat completions.
    """
    if (
        getattr(OllamaChatConfig.transform_request, "__name__", "")
        == "_patched_transform_request"
    ):
        return

    def _patched_transform_request(
        self: Any,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        stream = optional_params.pop("stream", False)
        format = optional_params.pop("format", None)
        keep_alive = optional_params.pop("keep_alive", None)
        think = optional_params.pop("think", None)
        function_name = optional_params.pop("function_name", None)
        litellm_params["function_name"] = function_name
        tools = optional_params.pop("tools", None)

        new_messages = []
        for m in messages:
            if isinstance(
                m, BaseModel
            ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
                m = m.model_dump(exclude_none=True)
            tool_calls = m.get("tool_calls")
            new_tools: List[OllamaToolCall] = []
            if tool_calls is not None and isinstance(tool_calls, list):
                for tool in tool_calls:
                    typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore[typeddict-item]
                    if typed_tool["type"] == "function":
                        arguments = {}
                        if "arguments" in typed_tool["function"]:
                            arguments = json.loads(typed_tool["function"]["arguments"])
                        ollama_tool_call = OllamaToolCall(
                            function=OllamaToolCallFunction(
                                name=typed_tool["function"].get("name") or "",
                                arguments=arguments,
                            )
                        )
                        new_tools.append(ollama_tool_call)
                cast(dict, m)["tool_calls"] = new_tools
            reasoning_content, parsed_content = _extract_reasoning_content(
                cast(dict, m)
            )
            content_str = convert_content_list_to_str(cast(AllMessageValues, m))
            images = extract_images_from_message(cast(AllMessageValues, m))

            ollama_message = OllamaChatCompletionMessage(
                role=cast(str, m.get("role")),
            )
            if reasoning_content is not None:
                ollama_message["thinking"] = reasoning_content
            if content_str is not None:
                ollama_message["content"] = content_str
            if images is not None:
                ollama_message["images"] = images
            if new_tools:
                ollama_message["tool_calls"] = new_tools

            new_messages.append(ollama_message)

        # Load Config
        config = self.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        data = {
            "model": model,
            "messages": new_messages,
            "options": optional_params,
            "stream": stream,
        }
        if format is not None:
            data["format"] = format
        if tools is not None:
            data["tools"] = tools
        if keep_alive is not None:
            data["keep_alive"] = keep_alive
        if think is not None:
            data["think"] = think

        return data

    OllamaChatConfig.transform_request = _patched_transform_request  # type: ignore[method-assign]


def _patch_ollama_chunk_parser() -> None:
    """
    Patches OllamaChatCompletionResponseIterator.chunk_parser to properly handle
    reasoning content and content in streaming responses.
    """
    if (
        getattr(OllamaChatCompletionResponseIterator.chunk_parser, "__name__", "")
        == "_patched_chunk_parser"
    ):
        return

    def _patched_chunk_parser(self: Any, chunk: dict) -> ModelResponseStream:
        try:
            """
            Expected chunk format:
            {
                "model": "llama3.1",
                "created_at": "2025-05-24T02:12:05.859654Z",
                "message": {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [{
                        "function": {
                            "name": "get_latest_album_ratings",
                            "arguments": {
                                "artist_name": "Taylor Swift"
                            }
                        }
                    }]
                },
                "done_reason": "stop",
                "done": true,
                ...
            }
            Need to:
            - convert 'message' to 'delta'
            - return finish_reason when done is true
            - return usage when done is true
            """
            from litellm.types.utils import Delta
            from litellm.types.utils import StreamingChoices

            # process tool calls - if complete function arg - add id to tool call
            tool_calls = chunk["message"].get("tool_calls")
            if tool_calls is not None:
                for tool_call in tool_calls:
                    function_args = tool_call.get("function").get("arguments")
                    if function_args is not None and len(function_args) > 0:
                        is_function_call_complete = self._is_function_call_complete(
                            function_args
                        )
                        if is_function_call_complete:
                            tool_call["id"] = str(uuid.uuid4())

            # PROCESS REASONING CONTENT
            reasoning_content: Optional[str] = None
            content: Optional[str] = None
            if chunk["message"].get("thinking") is not None:
                # Always process thinking content when present
                reasoning_content = chunk["message"].get("thinking")
                if self.started_reasoning_content is False:
                    self.started_reasoning_content = True
            elif chunk["message"].get("content") is not None:
                # Mark thinking as finished when we start getting regular content
                if (
                    self.started_reasoning_content
                    and not self.finished_reasoning_content
                ):
                    self.finished_reasoning_content = True

                message_content = chunk["message"].get("content")
                if "<think>" in message_content:
                    message_content = message_content.replace("<think>", "")
                    self.started_reasoning_content = True
                if "</think>" in message_content and self.started_reasoning_content:
                    message_content = message_content.replace("</think>", "")
                    self.finished_reasoning_content = True
                if (
                    self.started_reasoning_content
                    and not self.finished_reasoning_content
                ):
                    reasoning_content = message_content
                else:
                    content = message_content

            delta = Delta(
                content=content,
                reasoning_content=reasoning_content,
                tool_calls=tool_calls,
            )
            if chunk["done"] is True:
                finish_reason = chunk.get("done_reason", "stop")
                choices = [
                    StreamingChoices(
                        delta=delta,
                        finish_reason=finish_reason,
                    )
                ]
            else:
                choices = [
                    StreamingChoices(
                        delta=delta,
                    )
                ]

            usage = ChatCompletionUsageBlock(
                prompt_tokens=chunk.get("prompt_eval_count", 0),
                completion_tokens=chunk.get("eval_count", 0),
                total_tokens=chunk.get("prompt_eval_count", 0)
                + chunk.get("eval_count", 0),
            )

            return ModelResponseStream(
                id=str(uuid.uuid4()),
                object="chat.completion.chunk",
                created=int(time.time()),  # ollama created_at is in UTC
                usage=usage,
                model=chunk["model"],
                choices=choices,
            )
        except KeyError as e:
            raise OllamaError(
                message=f"KeyError: {e}, Got unexpected response from Ollama: {chunk}",
                status_code=400,
                headers={"Content-Type": "application/json"},
            )
        except Exception as e:
            raise e

    OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser  # type: ignore[method-assign]


def apply_monkey_patches() -> None:
    """
    Apply all necessary monkey patches to LiteLLM for Ollama compatibility.

    This includes:
    - Patching OllamaChatConfig.transform_request for reasoning content support
    - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
    """
    _patch_ollama_transform_request()
    _patch_ollama_chunk_parser()


def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        _parse_content_for_reasoning,
    )

    message_content = message.get("content")
    if "reasoning_content" in message:
        return message["reasoning_content"], message["content"]
    elif "reasoning" in message:
        return message["reasoning"], message["content"]
    elif isinstance(message_content, str):
        return _parse_content_for_reasoning(message_content)
    return None, message_content
@@ -132,7 +132,7 @@ def _build_bedrock_region_options() -> list[CustomConfigOption]:

BEDROCK_REGION_OPTIONS = _build_bedrock_region_options()

-OLLAMA_PROVIDER_NAME = "ollama"
+OLLAMA_PROVIDER_NAME = "ollama_chat"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"

# OpenRouter
@@ -53,6 +53,30 @@ ONE_MILLION = 1_000_000
CHUNKS_PER_DOC_ESTIMATE = 5


def _unwrap_nested_exception(error: Exception) -> Exception:
    """
    Traverse common exception wrappers to surface the underlying LiteLLM error.
    """
    visited: set[int] = set()
    current = error
    for _ in range(100):
        visited.add(id(current))
        candidate: Exception | None = None
        cause = getattr(current, "__cause__", None)
        if isinstance(cause, Exception):
            candidate = cause
        elif (
            hasattr(current, "args")
            and len(getattr(current, "args")) == 1
            and isinstance(current.args[0], Exception)
        ):
            candidate = current.args[0]
        if candidate is None or id(candidate) in visited:
            break
        current = candidate
    return current

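A quick sketch of what the unwrapping does with a hand-built nested error (the exception types here are stand-ins for whatever litellm actually raises):

    # Illustrative only - wrap an error two levels deep and recover it
    inner = ValueError("rate limit exceeded")
    middle = RuntimeError(inner)  # wrapper holding the error as its only arg
    try:
        raise Exception("outer") from middle  # wrapper linked via __cause__
    except Exception as outer:
        assert _unwrap_nested_exception(outer) is inner
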
def litellm_exception_to_error_msg(
    e: Exception,
    llm: LLM,
@@ -74,31 +98,58 @@ def litellm_exception_to_error_msg(
    from litellm.exceptions import ContentPolicyViolationError
    from litellm.exceptions import BudgetExceededError

-   error_msg = str(e)
+   core_exception = _unwrap_nested_exception(e)
+   error_msg = str(core_exception)

    if custom_error_msg_mappings:
        for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
            if error_msg_pattern in error_msg:
                return custom_error_msg

-   if isinstance(e, BadRequestError):
+   if isinstance(core_exception, BadRequestError):
        error_msg = "Bad request: The server couldn't process your request. Please check your input."
-   elif isinstance(e, AuthenticationError):
+   elif isinstance(core_exception, AuthenticationError):
        error_msg = "Authentication failed: Please check your API key and credentials."
-   elif isinstance(e, PermissionDeniedError):
+   elif isinstance(core_exception, PermissionDeniedError):
        error_msg = (
            "Permission denied: You don't have the necessary permissions for this operation."
            "Ensure you have access to this model."
        )
-   elif isinstance(e, NotFoundError):
+   elif isinstance(core_exception, NotFoundError):
        error_msg = "Resource not found: The requested resource doesn't exist."
-   elif isinstance(e, UnprocessableEntityError):
+   elif isinstance(core_exception, UnprocessableEntityError):
        error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
-   elif isinstance(e, RateLimitError):
-       error_msg = (
-           "Rate limit exceeded: Please slow down your requests and try again later."
-       )
+   elif isinstance(core_exception, RateLimitError):
+       provider_name = (
+           llm.config.model_provider
+           if llm is not None and llm.config.model_provider
+           else "The LLM provider"
+       )
+       upstream_detail: str | None = None
+       message_attr = getattr(core_exception, "message", None)
+       if message_attr:
+           upstream_detail = str(message_attr)
+       elif hasattr(core_exception, "api_error"):
+           api_error = core_exception.api_error  # type: ignore[attr-defined]
+           if isinstance(api_error, dict):
+               upstream_detail = (
+                   api_error.get("message")
+                   or api_error.get("detail")
+                   or api_error.get("error")
+               )
+       if not upstream_detail:
+           upstream_detail = str(core_exception)
+       upstream_detail = str(upstream_detail).strip()
+       if ":" in upstream_detail and upstream_detail.lower().startswith(
+           "ratelimiterror"
+       ):
+           upstream_detail = upstream_detail.split(":", 1)[1].strip()
+       error_msg = (
+           f"{provider_name} rate limit: {upstream_detail}"
+           if upstream_detail
+           else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
+       )
-   elif isinstance(e, ContextWindowExceededError):
+   elif isinstance(core_exception, ContextWindowExceededError):
        error_msg = (
            "Context window exceeded: Your input is too long for the model to process."
        )
@@ -113,18 +164,21 @@ def litellm_exception_to_error_msg(
        logger.warning(
            "Unable to get maximum input token for LiteLLM exception handling"
        )
-   elif isinstance(e, ContentPolicyViolationError):
+   elif isinstance(core_exception, ContentPolicyViolationError):
        error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
-   elif isinstance(e, APIConnectionError):
+   elif isinstance(core_exception, APIConnectionError):
        error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
-   elif isinstance(e, BudgetExceededError):
+   elif isinstance(core_exception, BudgetExceededError):
        error_msg = (
            "Budget exceeded: You've exceeded your allocated budget for API usage."
        )
-   elif isinstance(e, Timeout):
+   elif isinstance(core_exception, Timeout):
        error_msg = "Request timed out: The operation took too long to complete. Please try again."
-   elif isinstance(e, APIError):
-       error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
+   elif isinstance(core_exception, APIError):
+       error_msg = (
+           "API error: An error occurred while communicating with the API. "
+           f"Details: {str(core_exception)}"
+       )
    elif not fallback_to_error_msg:
        error_msg = "An unexpected error occurred while processing your request. Please try again later."
    return error_msg
@@ -385,7 +439,26 @@ def test_llm(llm: LLM) -> str | None:
def get_model_map() -> dict:
    import litellm

-   starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+   DIVIDER = "/"
+
+   original_map = cast(dict[str, dict], litellm.model_cost)
+   starting_map = copy.deepcopy(original_map)
+   for key in original_map:
+       if DIVIDER in key:
+           truncated_key = key.split(DIVIDER)[-1]
+           # make sure not to overwrite an original key
+           if truncated_key in original_map:
+               continue
+
+           # if there are multiple possible matches, choose the most "detailed"
+           # one as a heuristic. "detailed" = the description of the model
+           # has the most filled out fields.
+           existing_truncated_value = starting_map.get(truncated_key)
+           potential_truncated_value = original_map[key]
+           if not existing_truncated_value or len(potential_truncated_value) > len(
+               existing_truncated_value
+           ):
+               starting_map[truncated_key] = potential_truncated_value

    # NOTE: we could add additional models here in the future,
    # but for now there is no point. Ollama allows the user to
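The loop aliases provider-prefixed entries under their bare model name, preferring the richer entry when several prefixes collide. A sketch with fabricated map entries:

    # Fabricated model_cost entries to show the aliasing heuristic
    original_map = {
        "anthropic/claude-x": {"max_tokens": 4096, "supports_function_calling": True},
        "openrouter/claude-x": {"max_tokens": 4096},
    }
    # After the loop, starting_map also gains a bare "claude-x" key copied from
    # the anthropic entry, since it has more filled-out fields (2 > 1).
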
@@ -76,12 +76,17 @@ from onyx.server.features.input_prompt.api import (
 from onyx.server.features.mcp.api import admin_router as mcp_admin_router
 from onyx.server.features.mcp.api import router as mcp_router
 from onyx.server.features.notifications.api import router as notification_router
+from onyx.server.features.oauth_config.api import (
+    admin_router as admin_oauth_config_router,
+)
+from onyx.server.features.oauth_config.api import router as oauth_config_router
 from onyx.server.features.password.api import router as password_router
 from onyx.server.features.persona.api import admin_router as admin_persona_router
 from onyx.server.features.persona.api import basic_router as persona_router
 from onyx.server.features.projects.api import router as projects_router
 from onyx.server.features.tool.api import admin_router as admin_tool_router
 from onyx.server.features.tool.api import router as tool_router
+from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
 from onyx.server.federated.api import router as federated_router
 from onyx.server.gpts.api import router as gpts_router
 from onyx.server.kg.api import admin_router as kg_admin_router
@@ -369,6 +374,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
     include_router_with_global_prefix_prepended(application, notification_router)
     include_router_with_global_prefix_prepended(application, tool_router)
     include_router_with_global_prefix_prepended(application, admin_tool_router)
+    include_router_with_global_prefix_prepended(application, oauth_config_router)
+    include_router_with_global_prefix_prepended(application, admin_oauth_config_router)
+    include_router_with_global_prefix_prepended(application, user_oauth_token_router)
     include_router_with_global_prefix_prepended(application, state_router)
     include_router_with_global_prefix_prepended(application, onyx_api_router)
     include_router_with_global_prefix_prepended(application, gpts_router)
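include_router_with_global_prefix_prepended is onyx's own utility and its real body is not shown in this diff; its effect is plausibly equivalent to a prefix-forwarding wrapper like the hedged sketch below, where the /api prefix is an assumption for illustration only:

from fastapi import APIRouter, FastAPI

APP_API_PREFIX = "/api"  # assumed global prefix, illustration only


def include_router_with_global_prefix_prepended(
    application: FastAPI, router: APIRouter
) -> None:
    # Mount every feature router under one shared prefix so the new
    # oauth_config and user_oauth_token routes line up with the rest
    # of the API surface.
    application.include_router(router, prefix=APP_API_PREFIX)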
@@ -200,7 +200,7 @@ class CloudEmbedding:
         response = await client.embeddings.create(
             input=text_batch,
             model=model,
-            dimensions=reduced_dimension or openai.NOT_GIVEN,
+            dimensions=reduced_dimension or openai.omit,
         )
         final_embeddings.extend(
             [embedding.embedding for embedding in response.data]
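This one-line change swaps openai-python's older NOT_GIVEN sentinel for the newer omit, which leaves the dimensions key out of the request body entirely when reduced_dimension is falsy. A minimal standalone sketch; the model name and inputs are illustrative, not the repo's:

import openai


async def embed(texts: list[str], reduced_dimension: int | None) -> list[list[float]]:
    client = openai.AsyncOpenAI()  # assumes OPENAI_API_KEY is set
    response = await client.embeddings.create(
        input=texts,
        model="text-embedding-3-small",
        # Falsy reduced_dimension -> the dimensions key is omitted entirely.
        dimensions=reduced_dimension or openai.omit,
    )
    return [embedding.embedding for embedding in response.data]

# run with: asyncio.run(embed(["hello world"], None))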
@@ -11,12 +11,16 @@ Try to cite inline as opposed to leaving all citations until the very end of the
 """.rstrip()

 REQUIRE_CITATION_STATEMENT_V2 = """
-Cite relevant statements INLINE using the format [[1]](https://example.com) with the document number (an integer) in between
-the brackets. To cite multiple documents, use [[1]](https://example.com), [[2]](https://example.com) format instead of \
-[[1, 2]](https://example.com). \
+Cite relevant statements INLINE using the format [1], [3], etc. to reference the document_citation_number from the tool call response. \
+DO NOT provide any links following the citations. In other words, avoid using the format [1](https://example.com). \
+Avoid using double brackets like [[1]]. To cite multiple documents, use [1], [3] format instead of [1, 3]. \
 Try to cite inline as opposed to leaving all citations until the very end of the response.
 """.rstrip()

+STRESS_USER_PROMPT_IMPORTANCE = """
+Here is the user's prompt:
+"""
+
 NO_CITATION_STATEMENT = """
 Do not provide any citations even if there are examples in the chat history.
 """.rstrip()
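A quick hedged check of the format the V2 statement asks for: bare [n] citations with no trailing links and no double brackets. The regex below is illustrative, not the repo's citation parser:

import re

# Match [n] only when it is not part of [[n]] and not followed by a link.
V2_CITATION = re.compile(r"(?<!\[)\[(\d+)\](?!\()")

good = "Federated search is supported [1], [3]."
bad = "Federated search is supported [[1]](https://example.com)."

assert V2_CITATION.findall(good) == ["1", "3"]
assert V2_CITATION.findall(bad) == []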
@@ -29,6 +33,38 @@ PROJECT_INSTRUCTIONS_SEPARATOR = (
     "but only for style, formatting, and context]]\n"
 )

+LONG_CONVERSATION_REMINDER_TAG_OPEN = "<long_conversation_reminder>"
+LONG_CONVERSATION_REMINDER_TAG_CLOSED = "</long_conversation_reminder>"
+LONG_CONVERSATION_REMINDER_PROMPT = f"""
+A set of reminders may appear inside {LONG_CONVERSATION_REMINDER_TAG_OPEN} tags.
+This is added to the end of the person’s message. Behave in accordance with these instructions
+if they are relevant, and continue normally if they are not.
+"""
+
+# ruff: noqa: E501, W605 start
+DEFAULT_SYSTEM_PROMPT = """
+You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
+The current date is [[CURRENT_DATETIME]]
+
+You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
+You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\( [expression] \)' when inline.
+For code you prefer to use Markdown and specify the language.
+You can use Markdown horizontal rules (---) to separate sections of your responses.
+You can use Markdown tables to format your responses for data, lists, and other structured information.
+"""
+# ruff: noqa: E501, W605 end
+
+TOOL_PERSISTENCE_PROMPT = """
+You are an agent with the following tools. Please keep going until the user's query is
+completely resolved, before ending your turn and yielding back to the user.
+For more complicated queries, try to do more tool calls to obtain information relevant to the user's query.
+Only terminate your turn when you are sure that the problem is solved.\n"
+"""
+
 CUSTOM_INSTRUCTIONS_PROMPT = """
 The user has provided the following instructions, these are VERY IMPORTANT and must be adhered to at all times:
 """
+
+ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}."
+
+CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "
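DEFAULT_SYSTEM_PROMPT carries a [[CURRENT_DATETIME]] placeholder. A hedged sketch of how such a placeholder could be substituted at prompt-build time; the helper name and timestamp format are assumptions, not taken from the repo:

from datetime import datetime, timezone


def fill_current_datetime(prompt_template: str) -> str:
    # Render the placeholder with a human-readable UTC timestamp.
    now = datetime.now(timezone.utc).strftime("%A, %B %d, %Y %H:%M UTC")
    return prompt_template.replace("[[CURRENT_DATETIME]]", now)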
@@ -13,6 +13,9 @@ from onyx.context.search.models import InferenceChunk
 from onyx.db.models import Persona
 from onyx.prompts.chat_prompts import ADDITIONAL_INFO
 from onyx.prompts.chat_prompts import CITATION_REMINDER
+from onyx.prompts.chat_prompts import LONG_CONVERSATION_REMINDER_TAG_CLOSED
+from onyx.prompts.chat_prompts import LONG_CONVERSATION_REMINDER_TAG_OPEN
+from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT_V2
 from onyx.prompts.constants import CODE_BLOCK_PAT
 from onyx.prompts.direct_qa_prompts import COMPANY_DESCRIPTION_BLOCK
 from onyx.prompts.direct_qa_prompts import COMPANY_NAME_BLOCK
@@ -131,6 +134,40 @@ def build_task_prompt_reminders(
     return base_task + citation_or_nothing + language_hint_or_nothing


+def build_task_prompt_reminders_v2(
+    chat_turn_user_message: str,
+    prompt: Persona | PromptConfig,
+    use_language_hint: bool,
+    should_cite: bool,
+    language_hint_str: str = LANGUAGE_HINT,
+) -> str:
+    """V2 version that conditionally includes citation requirements.
+
+    Args:
+        prompt: Persona or PromptConfig with task_prompt
+        use_language_hint: Whether to include language hint
+        should_cite: Whether to include citation requirement statement
+        language_hint_str: Language hint string to use
+
+    Returns:
+        Task prompt with optional citation statement and language hint
+    """
+    base_task = prompt.task_prompt or ""
+    citation_or_nothing = REQUIRE_CITATION_STATEMENT_V2 if should_cite else ""
+    language_hint_or_nothing = language_hint_str.lstrip() if use_language_hint else ""
+    if len(base_task) + len(citation_or_nothing) + len(language_hint_or_nothing) > 0:
+        return f"""
+{LONG_CONVERSATION_REMINDER_TAG_OPEN}
+{base_task}
+{citation_or_nothing}
+{language_hint_or_nothing}
+{LONG_CONVERSATION_REMINDER_TAG_CLOSED}
+{chat_turn_user_message}
+"""
+    else:
+        return chat_turn_user_message
+
+
 # Maps connector enum string to a more natural language representation for the LLM
 # If not on the list, uses the original but slightly cleaned up, see below
 CONNECTOR_NAME_MAP = {
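A hedged usage sketch of build_task_prompt_reminders_v2 as added above; it assumes the function is in scope, and the stand-in prompt object only needs the task_prompt attribute the function reads:

from types import SimpleNamespace

# Stand-in for a Persona/PromptConfig; illustrative only.
prompt_config = SimpleNamespace(task_prompt="Answer concisely.")

wrapped = build_task_prompt_reminders_v2(
    chat_turn_user_message="What changed in the last release?",
    prompt=prompt_config,
    use_language_hint=False,
    should_cite=True,
)
# wrapped now places the task prompt and citation statement inside
# <long_conversation_reminder> ... </long_conversation_reminder>,
# followed by the user's actual message.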
6
backend/onyx/server/features/oauth_config/__init__.py
Normal file
@@ -0,0 +1,6 @@
+"""OAuth configuration feature module."""
+
+from onyx.server.features.oauth_config.api import admin_router
+from onyx.server.features.oauth_config.api import router
+
+__all__ = ["admin_router", "router"]
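The re-exports in this new __init__.py let callers import the routers from the package root instead of the api submodule, for example:

from onyx.server.features.oauth_config import admin_router, router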
Some files were not shown because too many files have changed in this diff.