Compare commits

6 Commits

Author  SHA1  Message  Date

Raunak Bhagat  9e9c3ec0b9  Remove unused imports  2025-11-18 13:51:10 -08:00
Raunak Bhagat  1457ca2a20  Make share button instantaneous  2025-11-18 13:50:37 -08:00
Raunak Bhagat  edc390edc6  Implement AppPage wrapper for all other pages inside of /chat  2025-11-18 13:34:38 -08:00
Raunak Bhagat  022624cb5a  Maintain consistent heights  2025-11-18 13:20:09 -08:00
Raunak Bhagat  f301257130  Make chatSession info and settings info be passed in as server-side data  2025-11-18 13:07:52 -08:00
Raunak Bhagat  9eecc71cda  Fix flashing  2025-11-18 11:43:49 -08:00
195 changed files with 1829 additions and 10602 deletions

View File

@@ -7,9 +7,9 @@ runs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
restore-keys: |
${{ runner.os }}-${{ runner.arch }}-playwright-
${{ runner.os }}-playwright-
- name: Install playwright
shell: bash

View File

@@ -10,9 +10,6 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
@@ -20,8 +17,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6

View File

@@ -6,9 +6,6 @@ on:
- "*"
workflow_dispatch:
permissions:
contents: read
env:
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -33,7 +30,7 @@ jobs:
- name: Check which components to build and version info
id: check
run: |
TAG="${GITHUB_REF_NAME}"
TAG="${{ github.ref_name }}"
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
IS_CLOUD=false
@@ -82,143 +79,22 @@ jobs:
echo "sanitized-tag=$SANITIZED_TAG"
} >> "$GITHUB_OUTPUT"
build-web-amd64:
build-web:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-amd64
- run-id=${{ github.run_id }}-web-build
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
DEPLOYMENT: standalone
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-web-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-web:
needs:
- determine-builds
- build-web-amd64
- build-web-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -233,171 +109,50 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
build-web-cloud-amd64:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache,mode=max
build-web-cloud:
needs: determine-builds
if: needs.determine-builds.outputs.build-web-cloud == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-cloud-amd64
- run-id=${{ github.run_id }}-web-cloud-build
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
DEPLOYMENT: cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NEXT_PUBLIC_CLOUD_ENABLED=true
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-web-cloud-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web-cloud == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NEXT_PUBLIC_CLOUD_ENABLED=true
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-web-cloud:
needs:
- determine-builds
- build-web-cloud-amd64
- build-web-cloud-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -409,153 +164,58 @@ jobs:
tags: |
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
build-backend-amd64:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NEXT_PUBLIC_CLOUD_ENABLED=true
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache,mode=max
build-backend:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-amd64
- run-id=${{ github.run_id }}-backend-build
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-backend:
needs:
- determine-builds
- build-backend-amd64
- build-backend-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -570,51 +230,8 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
build-model-server-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-model-server-amd64
- volume=40gb
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -622,115 +239,43 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
provenance: false
sbom: false
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache,mode=max
build-model-server-arm64:
build-model-server:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-model-server-arm64
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-model-server-build
- volume=40gb
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
DOCKER_BUILDKIT: 1
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
provenance: false
sbom: false
merge-model-server:
needs:
- determine-builds
- build-model-server-amd64
- build-model-server-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -745,26 +290,43 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache,mode=max
trivy-scan-web:
needs:
- determine-builds
- merge-web
if: needs.merge-web.result == 'success'
needs: [determine-builds, build-web]
if: needs.build-web.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
env:
@@ -797,13 +359,11 @@ jobs:
${SCAN_IMAGE}
trivy-scan-web-cloud:
needs:
- determine-builds
- merge-web-cloud
if: needs.merge-web-cloud.result == 'success'
needs: [determine-builds, build-web-cloud]
if: needs.build-web-cloud.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
env:
@@ -836,13 +396,11 @@ jobs:
${SCAN_IMAGE}
trivy-scan-backend:
needs:
- determine-builds
- merge-backend
if: needs.merge-backend.result == 'success'
needs: [determine-builds, build-backend]
if: needs.build-backend.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
env:
@@ -852,8 +410,6 @@ jobs:
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
@@ -882,13 +438,11 @@ jobs:
${SCAN_IMAGE}
trivy-scan-model-server:
needs:
- determine-builds
- merge-model-server
if: needs.merge-model-server.result == 'success'
needs: [determine-builds, build-model-server]
if: needs.build-model-server.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
env:
@@ -921,85 +475,33 @@ jobs:
${SCAN_IMAGE}
notify-slack-on-failure:
needs:
- build-web-amd64
- build-web-arm64
- merge-web
- build-web-cloud-amd64
- build-web-cloud-arm64
- merge-web-cloud
- build-backend-amd64
- build-backend-arm64
- merge-backend
- build-model-server-amd64
- build-model-server-arm64
- merge-model-server
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
needs: [build-web, build-web-cloud, build-backend, build-model-server]
if: always() && (needs.build-web.result == 'failure' || needs.build-web-cloud.result == 'failure' || needs.build-backend.result == 'failure' || needs.build-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
runs-on: ubuntu-slim
steps:
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Determine failed jobs
id: failed-jobs
shell: bash
run: |
FAILED_JOBS=""
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
if [ "${{ needs.build-web.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web\\n"
fi
if [ "${NEEDS_BUILD_WEB_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-arm64\\n"
if [ "${{ needs.build-web-cloud.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud\\n"
fi
if [ "${NEEDS_MERGE_WEB_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-web\\n"
if [ "${{ needs.build-backend.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend\\n"
fi
if [ "${NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-amd64\\n"
fi
if [ "${NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-arm64\\n"
fi
if [ "${NEEDS_MERGE_WEB_CLOUD_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-web-cloud\\n"
fi
if [ "${NEEDS_BUILD_BACKEND_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend-amd64\\n"
fi
if [ "${NEEDS_BUILD_BACKEND_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend-arm64\\n"
fi
if [ "${NEEDS_MERGE_BACKEND_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-backend\\n"
fi
if [ "${NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server-amd64\\n"
fi
if [ "${NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server-arm64\\n"
fi
if [ "${NEEDS_MERGE_MODEL_SERVER_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-model-server\\n"
if [ "${{ needs.build-model-server.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server\\n"
fi
# Remove trailing \n and set output
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
env:
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT: ${{ needs.build-web-cloud-amd64.result }}
NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT: ${{ needs.build-web-cloud-arm64.result }}
NEEDS_MERGE_WEB_CLOUD_RESULT: ${{ needs.merge-web-cloud.result }}
NEEDS_BUILD_BACKEND_AMD64_RESULT: ${{ needs.build-backend-amd64.result }}
NEEDS_BUILD_BACKEND_ARM64_RESULT: ${{ needs.build-backend-arm64.result }}
NEEDS_MERGE_BACKEND_RESULT: ${{ needs.merge-backend.result }}
NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT: ${{ needs.build-model-server-amd64.result }}
NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT: ${{ needs.build-model-server-arm64.result }}
NEEDS_MERGE_MODEL_SERVER_RESULT: ${{ needs.merge-model-server.result }}
- name: Send Slack notification
uses: ./.github/actions/slack-notify
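
For orientation, the refactor in this file replaces each pair of per-architecture build jobs plus a manifest-merge job with a single docker/build-push-action step that builds both platforms at once. A minimal sketch of that pattern, using a hypothetical job name and image tag rather than anything taken from this workflow, looks like:

  # Sketch only: "build-example" and "example/app:latest" are hypothetical.
  build-example:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: docker/setup-buildx-action@v3
      - uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}
      - uses: docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          # One multi-platform build replaces the old amd64/arm64 split jobs
          # and the follow-up "docker buildx imagetools create" merge step.
          platforms: linux/amd64,linux/arm64
          push: true
          tags: example/app:latest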

View File

@@ -10,9 +10,6 @@ on:
description: "The version (ie v1.0.0-beta.0) to tag as beta"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -32,19 +29,13 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}

View File

@@ -10,9 +10,6 @@ on:
description: "The version (ie v0.0.1) to tag as latest"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -32,19 +29,13 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}

View File

@@ -17,7 +17,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install Helm CLI
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4

View File

@@ -15,21 +15,16 @@ on:
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
@@ -59,9 +54,7 @@ jobs:
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web

View File

@@ -8,9 +8,6 @@ on:
pull_request:
branches: [main]
permissions:
contents: read
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
@@ -40,8 +37,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -72,8 +67,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -104,12 +97,10 @@ jobs:
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
env:
TEST_DIR: ${{ matrix.test-dir }}
run: |
py.test \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/external_dependency_unit/${TEST_DIR}
backend/tests/external_dependency_unit/${{ matrix.test-dir }}

View File

@@ -9,9 +9,6 @@ on:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
permissions:
contents: read
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
@@ -23,7 +20,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Set up Helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
@@ -36,11 +32,9 @@ jobs:
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
echo "default_branch: ${DEFAULT_BRANCH}"
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"

View File

@@ -10,9 +10,6 @@ on:
- main
- "release/**"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -40,8 +37,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -70,8 +65,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -92,11 +85,8 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
@@ -106,8 +96,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -128,10 +116,8 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
build-integration-image:
@@ -140,8 +126,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -157,16 +141,9 @@ jobs:
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
run: cd backend && docker buildx bake --push integration
integration-tests:
needs:
@@ -191,8 +168,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -206,9 +181,6 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -217,8 +189,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
@@ -350,8 +322,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -360,9 +330,6 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers for multi-tenant tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -370,8 +337,8 @@ jobs:
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
@@ -412,9 +379,6 @@ jobs:
echo "Finished waiting for service."
- name: Run Multi-Tenant Integration Tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx_default \
@@ -438,7 +402,7 @@ jobs:
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e DEV_MODE=true \
${ECR_CACHE}:integration-test-${RUN_ID} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/multitenant_tests
- name: Dump API server logs (multi-tenant)
@@ -472,6 +436,13 @@ jobs:
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}

View File

@@ -5,9 +5,6 @@ concurrency:
on: push
permissions:
contents: read
jobs:
jest-tests:
name: Jest Tests
@@ -15,8 +12,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4

View File

@@ -1,7 +1,7 @@
name: PR Labeler
on:
pull_request:
pull_request_target:
branches:
- main
types:
@@ -12,6 +12,7 @@ on:
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:

View File

@@ -7,9 +7,6 @@ on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
linear-check:
runs-on: ubuntu-latest

View File

@@ -7,9 +7,6 @@ on:
merge_group:
types: [checks_requested]
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -36,8 +33,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -65,8 +60,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -87,10 +80,8 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -99,8 +90,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -121,10 +110,8 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
@@ -132,8 +119,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -149,16 +134,9 @@ jobs:
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
run: cd backend && docker buildx bake --push integration
integration-tests-mit:
needs:
@@ -183,8 +161,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -198,9 +174,6 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
@@ -208,8 +181,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
@@ -334,6 +307,13 @@ jobs:
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}

View File

@@ -5,9 +5,6 @@ concurrency:
on: push
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -45,8 +42,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -67,10 +62,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
type=registry,ref=onyxdotapp/onyx-web-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-image:
@@ -80,8 +73,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -102,11 +93,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -116,8 +104,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -138,10 +124,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
playwright-tests:
@@ -159,7 +143,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
@@ -185,22 +168,17 @@ jobs:
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
EXA_API_KEY=${{ env.EXA_API_KEY }}
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
ONYX_WEB_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
EOF
# needed for pulling Vespa, Redis, Postgres, and Minio images
@@ -277,12 +255,10 @@ jobs:
- name: Run Playwright tests
working-directory: ./web
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${PROJECT}
npx playwright test --project ${{ matrix.project }}
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
@@ -295,12 +271,10 @@ jobs:
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
@@ -309,16 +283,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.


@@ -10,9 +10,6 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
@@ -24,8 +21,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit


@@ -11,9 +11,6 @@ on:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
permissions:
contents: read
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
@@ -135,8 +132,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -219,10 +214,8 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK


@@ -11,9 +11,6 @@ on:
required: false
default: 'main'
permissions:
contents: read
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -39,8 +36,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -127,12 +122,10 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)


@@ -10,9 +10,6 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
backend-check:
# See https://runs-on.com/runners/linux/
@@ -31,8 +28,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies


@@ -7,9 +7,6 @@ on:
merge_group:
pull_request: null
permissions:
contents: read
jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
@@ -19,7 +16,6 @@ jobs:
- uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: "3.11"


@@ -16,7 +16,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install git-filter-repo
run: |


@@ -3,29 +3,30 @@ name: Nightly Tag Push
on:
schedule:
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
workflow_dispatch:
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: ubuntu-slim
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-create-and-push-tag"]
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
# and not rkuo's personal account. It is fine to leave this key as is!
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
- name: Set up Git user
run: |
git config user.name "Onyx Bot [bot]"
git config user.email "onyx-bot[bot]@onyx.app"
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@onyx.app"
- name: Check for existing nightly tag
id: check_tag
@@ -53,12 +54,3 @@ jobs:
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME
- name: Send Slack notification
if: failure()
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
title: "🚨 Nightly Tag Push Failed"
ref-name: ${{ github.ref_name }}
failed-jobs: "create-and-push-tag"


@@ -1,35 +0,0 @@
name: Run Zizmor
on:
push:
branches: ["main"]
pull_request:
branches: ["**"]
permissions: {}
jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
permissions:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # ratchet:actions/checkout@v5.0.1
with:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3
- name: Run zizmor
run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif
category: zizmor


@@ -1,89 +0,0 @@
"""add internet search and content provider tables
Revision ID: 1f2a3b4c5d6e
Revises: 9drpiiw74ljy
Create Date: 2025-11-10 19:45:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1f2a3b4c5d6e"
down_revision = "9drpiiw74ljy"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"internet_search_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_search_provider_is_active",
"internet_search_provider",
["is_active"],
)
op.create_table(
"internet_content_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_content_provider_is_active",
"internet_content_provider",
["is_active"],
)
def downgrade() -> None:
op.drop_index(
"ix_internet_content_provider_is_active", table_name="internet_content_provider"
)
op.drop_table("internet_content_provider")
op.drop_index(
"ix_internet_search_provider_is_active", table_name="internet_search_provider"
)
op.drop_table("internet_search_provider")


@@ -1,16 +1,4 @@
group "default" {
targets = ["backend", "model-server"]
}
variable "BACKEND_REPOSITORY" {
default = "onyxdotapp/onyx-backend"
}
variable "MODEL_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-model-server"
}
variable "INTEGRATION_REPOSITORY" {
variable "REPOSITORY" {
default = "onyxdotapp/onyx-integration"
}
@@ -21,22 +9,6 @@ variable "TAG" {
target "backend" {
context = "."
dockerfile = "Dockerfile"
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}
target "model-server" {
context = "."
dockerfile = "Dockerfile.model_server"
cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
}
target "integration" {
@@ -48,5 +20,8 @@ target "integration" {
base = "target:backend"
}
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
cache-from = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache"]
cache-to = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache,mode=max"]
tags = ["${REPOSITORY}:${TAG}"]
}


@@ -116,7 +116,7 @@ def _concurrent_embedding(
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.warning(f"Error encoding texts, retrying: {e}")
logger.error(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)
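
A minimal sketch of the retry pattern the comment above describes (the helper name and delay value are illustrative, not the actual Onyx code):

import time

def encode_with_retry(model, texts, normalize_embeddings=True, retry_delay=1.0):
    try:
        return model.encode(texts, normalize_embeddings=normalize_embeddings)
    except RuntimeError:
        # Rare "Already borrowed" tokenizer contention under concurrent encoding;
        # sleep briefly and retry once.
        time.sleep(retry_delay)
        return model.encode(texts, normalize_embeddings=normalize_embeddings)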


@@ -1,73 +0,0 @@
import json
from collections.abc import Sequence
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import FunctionMessage
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import FunctionCall
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithText
HUMAN = "human"
SYSTEM = "system"
AI = "ai"
FUNCTION = "function"
def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
def _base_message_to_chat_completion_msg(
msg: BaseMessage,
) -> ChatCompletionMessage:
if msg.type == HUMAN:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
user_msg: UserMessageWithText = {"role": "user", "content": content}
return user_msg
if msg.type == SYSTEM:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
system_msg: SystemMessage = {"role": "system", "content": content}
return system_msg
if msg.type == AI:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
if isinstance(msg, AIMessage) and msg.tool_calls:
assistant_msg["tool_calls"] = [
ToolCall(
id=tool_call.get("id") or "",
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=json.dumps(tool_call["args"]),
),
)
for tool_call in msg.tool_calls
]
return assistant_msg
if msg.type == FUNCTION:
function_message = cast(FunctionMessage, msg)
content = (
function_message.content
if isinstance(function_message.content, str)
else str(function_message.content)
)
tool_msg: ToolMessage = {
"role": "tool",
"content": content,
"tool_call_id": function_message.name or "",
}
return tool_msg
raise ValueError(f"Unexpected message type: {msg.type}")


@@ -1,11 +1,9 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
@@ -18,10 +16,6 @@ from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
@dataclass
@@ -39,75 +33,6 @@ def _serialize_tool_output(output: Any) -> str:
return str(output)
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
tool_calls: list[dict[str, Any]] = []
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
if not isinstance(name, str) or arguments is None:
continue
if not isinstance(arguments, dict):
continue
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
if not tool_calls_from_content:
return
content_parts.clear()
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
@@ -140,225 +65,150 @@ def query(
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
new_messages_stateful: list[ChatCompletionMessage] = []
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
def stream_generator() -> Iterator[StreamEvent]:
message_started = False
reasoning_started = False
message_started = False
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
content_parts: list[str] = []
reasoning_parts: list[str] = []
synthetic_tool_call_counter = 0
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
):
assert isinstance(chunk, ModelResponseStream)
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
yield chunk
if not finish_reason:
continue
if delta.reasoning_content:
reasoning_parts.append(delta.reasoning_content)
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
if delta.content:
content_parts.append(delta.content)
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started and not message_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
if content_parts:
new_messages_stateful.append(
yield chunk
if not finish_reason:
continue
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
if finish_reason == "tool_calls" and tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
assistant_tool_calls.append(
{
"role": "assistant",
"content": "".join(content_parts),
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
span_generation.span_data.output = new_messages_stateful
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
run_context = RunContextWrapper(context=context)
if call_id is None or name is None:
continue
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
assistant_tool_calls.append(
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
new_messages_stateful.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
elif finish_reason == "stop" and content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),
new_messages_stateful=new_messages_stateful,
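
For reference, the JSON shape the removed _parse_tool_calls_from_message_content helper accepted from non-tool-calling LLMs (the values here are purely illustrative):

import json

content = '[{"id": "call_0", "name": "web_search", "arguments": {"query": "onyx"}}]'
parsed = json.loads(content)
# Entries with a string "name" and a dict "arguments" were kept, with the
# arguments re-serialized to a string:
# [{"id": "call_0", "name": "web_search", "arguments": '{"query": "onyx"}'}]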


@@ -26,9 +26,9 @@ def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> Non
# Without this patch, the library uses special formatting that breaks our custom tools
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
if tool_choice == "web_search":
return "web_search"
return {"type": "function", "name": "web_search"}
if tool_choice == "image_generation":
return "image_generation"
return {"type": "function", "name": "image_generation"}
return orig_func(cls, tool_choice)
OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]


@@ -12,12 +12,13 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)


@@ -1,163 +0,0 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
FIRECRAWL_SCRAPE_URL = "https://api.firecrawl.dev/v1/scrape"
_DEFAULT_MAX_WORKERS = 4
@dataclass
class ExtractedContentFields:
text: str
title: str
published_date: datetime | None
class FirecrawlClient(WebContentProvider):
def __init__(
self,
api_key: str,
*,
base_url: str = FIRECRAWL_SCRAPE_URL,
timeout_seconds: int = 30,
) -> None:
self._headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
self._base_url = base_url
self._timeout_seconds = timeout_seconds
self._last_error: str | None = None
@property
def last_error(self) -> str | None:
return self._last_error
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
max_workers = min(_DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._get_webpage_content_safe, urls))
def _get_webpage_content_safe(self, url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception as exc:
self._last_error = str(exc)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
"formats": ["markdown"],
}
response = requests.post(
self._base_url,
headers=self._headers,
json=payload,
timeout=self._timeout_seconds,
)
if response.status_code != 200:
try:
error_payload = response.json()
except Exception:
error_payload = response.text
self._last_error = (
error_payload if isinstance(error_payload, str) else str(error_payload)
)
if 400 <= response.status_code < 500:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
raise ValueError(
f"Firecrawl fetch failed with status {response.status_code}."
)
else:
self._last_error = None
response_json = response.json()
extracted = self._extract_content_fields(response_json, url)
return WebContent(
title=extracted.title,
link=url,
full_content=extracted.text,
published_date=extracted.published_date,
scrape_successful=bool(extracted.text),
)
@staticmethod
def _extract_content_fields(
response_json: dict[str, Any], url: str
) -> ExtractedContentFields:
data_section = response_json.get("data") or {}
metadata = data_section.get("metadata") or response_json.get("metadata") or {}
text_candidates = [
data_section.get("markdown"),
data_section.get("content"),
data_section.get("text"),
response_json.get("markdown"),
response_json.get("content"),
response_json.get("text"),
]
text = next((candidate for candidate in text_candidates if candidate), "")
title = metadata.get("title") or response_json.get("title") or ""
published_date = None
published_date_str = (
metadata.get("publishedTime")
or metadata.get("date")
or response_json.get("publishedTime")
or response_json.get("date")
)
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
if not text:
logger.warning(f"Firecrawl returned empty content for url={url}")
return ExtractedContentFields(
text=text or "",
title=title or "",
published_date=published_date,
)
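
A hypothetical usage sketch of the removed Firecrawl content provider (the API key is a placeholder):

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
    FirecrawlClient,
)

client = FirecrawlClient(api_key="fc-...")
pages = client.contents(["https://www.onyx.app"])
for page in pages:
    print(page.link, page.scrape_successful, len(page.full_content))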


@@ -1,138 +0,0 @@
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from typing import Any
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
GOOGLE_CUSTOM_SEARCH_URL = "https://customsearch.googleapis.com/customsearch/v1"
class GooglePSEClient(WebSearchProvider):
def __init__(
self,
api_key: str,
search_engine_id: str,
*,
num_results: int = 10,
timeout_seconds: int = 10,
) -> None:
self._api_key = api_key
self._search_engine_id = search_engine_id
self._num_results = num_results
self._timeout_seconds = timeout_seconds
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
params: dict[str, str] = {
"key": self._api_key,
"cx": self._search_engine_id,
"q": query,
"num": str(self._num_results),
}
response = requests.get(
GOOGLE_CUSTOM_SEARCH_URL, params=params, timeout=self._timeout_seconds
)
# Check for HTTP errors first
try:
response.raise_for_status()
except requests.HTTPError as exc:
status = response.status_code
error_detail = "Unknown error"
try:
error_data = response.json()
if "error" in error_data:
error_info = error_data["error"]
error_detail = error_info.get("message", str(error_info))
except Exception:
error_detail = (
response.text[:200] if response.text else "No error details"
)
raise ValueError(
f"Google PSE search failed (status {status}): {error_detail}"
) from exc
data = response.json()
# Google Custom Search API can return errors in the response body even with 200 status
if "error" in data:
error_info = data["error"]
error_message = error_info.get("message", "Unknown error")
error_code = error_info.get("code", "Unknown")
raise ValueError(f"Google PSE API error ({error_code}): {error_message}")
items: list[dict[str, Any]] = data.get("items", [])
results: list[WebSearchResult] = []
for item in items:
link = item.get("link")
if not link:
continue
snippet = item.get("snippet") or ""
# Attempt to extract metadata if available
pagemap = item.get("pagemap") or {}
metatags = pagemap.get("metatags", [])
published_date: datetime | None = None
author: str | None = None
if metatags:
meta = metatags[0]
author = meta.get("og:site_name") or meta.get("author")
published_str = (
meta.get("article:published_time")
or meta.get("og:updated_time")
or meta.get("date")
)
if published_str:
try:
published_date = datetime.fromisoformat(
published_str.replace("Z", "+00:00")
)
except ValueError:
logger.debug(
f"Failed to parse published_date '{published_str}' for link {link}"
)
published_date = None
results.append(
WebSearchResult(
title=item.get("title") or "",
link=link,
snippet=snippet,
author=author,
published_date=published_date,
)
)
return results
def contents(self, urls: Sequence[str]) -> list[WebContent]:
logger.warning(
"Google PSE does not support content fetching; returning empty results."
)
return [
WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
for url in urls
]
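
A hypothetical usage sketch of the removed Google PSE search provider (both credentials are placeholders):

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
    GooglePSEClient,
)

client = GooglePSEClient(api_key="...", search_engine_id="...")
for result in client.search("onyx open source"):
    print(result.title, result.link)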


@@ -1,94 +0,0 @@
from __future__ import annotations
from collections.abc import Sequence
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
logger = setup_logger()
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
class OnyxWebCrawlerClient(WebContentProvider):
"""
Lightweight built-in crawler that fetches HTML directly and extracts readable text.
Acts as the default content provider when no external crawler (e.g. Firecrawl) is
configured.
"""
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
) -> None:
self._timeout_seconds = timeout_seconds
self._headers = {
"User-Agent": user_agent,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
def _fetch_url(self, url: str) -> WebContent:
try:
response = requests.get(
url, headers=self._headers, timeout=self._timeout_seconds
)
except Exception as exc: # pragma: no cover - network failures vary
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
try:
parsed: ParsedHTML = web_html_cleanup(response.text)
text_content = parsed.cleaned_text or ""
title = parsed.title or ""
except Exception as exc:
logger.warning(
"Onyx crawler failed to parse %s (%s)", url, exc.__class__.__name__
)
text_content = ""
title = ""
return WebContent(
title=title,
link=url,
full_content=text_content,
published_date=None,
scrape_successful=bool(text_content.strip()),
)


@@ -13,6 +13,7 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
@@ -21,7 +22,7 @@ SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
@@ -39,13 +40,7 @@ class SerperClient(WebSearchProvider):
data=json.dumps(payload),
)
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper search failed. Check credentials or quota."
) from None
response.raise_for_status()
results = response.json()
organic_results = results["organic"]
@@ -104,13 +99,7 @@ class SerperClient(WebSearchProvider):
scrape_successful=False,
)
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper content fetch failed. Check credentials."
) from None
response.raise_for_status()
response_json = response.json()


@@ -74,22 +74,13 @@ def web_search(
if not provider:
raise ValueError("No internet search provider found")
# Log which provider type is being used
provider_type = type(provider).__name__
logger.info(
f"Performing web search with {provider_type} for query: '{search_query}'"
)
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = list(provider.search(search_query))
logger.info(
f"Search returned {len(search_results)} results using {provider_type}"
)
except Exception as e:
logger.error(f"Error performing search with {provider_type}: {e}")
logger.error(f"Error performing search: {e}")
return search_results
search_results: list[WebSearchResult] = _search(search_query)


@@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
@@ -41,9 +41,9 @@ def web_fetch(
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
provider = get_default_content_provider()
provider = get_default_provider()
if provider is None:
raise ValueError("No web content provider found")
raise ValueError("No web search provider found")
retrieved_docs: list[InferenceSection] = []
try:
@@ -52,7 +52,7 @@ def web_fetch(
for result in provider.contents(state.urls_to_open)
]
except Exception as e:
logger.exception(e)
logger.error(f"Error fetching URLs: {e}")
if not retrieved_docs:
logger.warning("No content retrieved from URLs")


@@ -2,6 +2,7 @@ from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import field_validator
@@ -9,6 +10,13 @@ from pydantic import field_validator
from onyx.utils.url import normalize_url
class ProviderType(Enum):
"""Enum for internet search provider types"""
GOOGLE = "google"
EXA = "exa"
class WebSearchResult(BaseModel):
title: str
link: str
@@ -35,13 +43,11 @@ class WebContent(BaseModel):
return normalize_url(v)
class WebContentProvider(ABC):
@abstractmethod
def contents(self, urls: Sequence[str]) -> list[WebContent]:
pass
class WebSearchProvider(WebContentProvider):
class WebSearchProvider(ABC):
@abstractmethod
def search(self, query: str) -> Sequence[WebSearchResult]:
pass
@abstractmethod
def contents(self, urls: Sequence[str]) -> list[WebContent]:
pass


@@ -1,199 +1,19 @@
from typing import Any
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FIRECRAWL_SCRAPE_URL,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FirecrawlClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
GooglePSEClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def build_search_provider_from_config(
*,
provider_type: WebSearchProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_search_provider",
) -> WebSearchProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebSearchProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
# All web search providers require an API key
if not api_key:
raise ValueError(
f"Web search provider '{provider_name}' is missing an API key."
)
assert api_key is not None
config = config or {}
if provider_type_enum == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
search_engine_id = (
config.get("search_engine_id")
or config.get("cx")
or config.get("search_engine")
)
if not search_engine_id:
raise ValueError(
"Google PSE provider requires a search engine id (cx) in addition to the API key."
)
assert search_engine_id is not None
try:
num_results = int(config.get("num_results", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'num_results'; expected integer."
)
try:
timeout_seconds = int(config.get("timeout_seconds", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'timeout_seconds'; expected integer."
)
return GooglePSEClient(
api_key=api_key,
search_engine_id=search_engine_id,
num_results=num_results,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
return build_search_provider_from_config(
provider_type=WebSearchProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
def build_content_provider_from_config(
*,
provider_type: WebContentProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_content_provider",
) -> WebContentProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebContentProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
config = config or {}
timeout_value = config.get("timeout_seconds", 15)
try:
timeout_seconds = int(timeout_value)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Onyx Web Crawler 'timeout_seconds'; expected integer."
)
return OnyxWebCrawlerClient(timeout_seconds=timeout_seconds)
if provider_type_enum == WebContentProviderType.FIRECRAWL:
if not api_key:
raise ValueError("Firecrawl content provider requires an API key.")
assert api_key is not None
config = config or {}
timeout_seconds_str = config.get("timeout_seconds")
if timeout_seconds_str is None:
timeout_seconds = 10
else:
try:
timeout_seconds = int(timeout_seconds_str)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Firecrawl 'timeout_seconds'; expected integer."
)
return FirecrawlClient(
api_key=api_key,
base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
return build_content_provider_from_config(
provider_type=WebContentProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> WebSearchProvider | None:
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_search_provider(db_session)
if provider_model is None:
return None
return _build_search_provider(provider_model)
def get_default_content_provider() -> WebContentProvider | None:
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_content_provider(db_session)
if provider_model:
provider = _build_content_provider(provider_model)
if provider:
return provider
# Fall back to built-in Onyx crawler when nothing is configured.
try:
return OnyxWebCrawlerClient()
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to initialize default Onyx crawler: {exc}")
return None
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:
return SerperClient()
return None
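
A minimal sketch of how a caller would use the simplified factory that remains above; the error message mirrors the web_search tool earlier in this diff:

provider = get_default_provider()
if provider is None:
    raise ValueError("No internet search provider found")
results = provider.search("example query")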


@@ -11,8 +11,11 @@ from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -128,7 +131,7 @@ def _resolve_public_key_from_jwks(
return None
async def verify_jwt_token(token: str) -> dict[str, Any] | None:
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
public_key = get_public_key(token)
if public_key is None:
@@ -139,6 +142,8 @@ async def verify_jwt_token(token: str) -> dict[str, Any] | None:
return None
try:
from sqlalchemy import func
payload = jwt_decode(
token,
public_key,
@@ -158,6 +163,15 @@ async def verify_jwt_token(token: str) -> dict[str, Any] | None:
continue
return None
return payload
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
break
return None


@@ -1063,107 +1063,6 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
_JWT_EMAIL_CLAIM_KEYS = ("email", "preferred_username", "upn")
def _extract_email_from_jwt(payload: dict[str, Any]) -> str | None:
"""Return the best-effort email/username from a decoded JWT payload."""
for key in _JWT_EMAIL_CLAIM_KEYS:
value = payload.get(key)
if isinstance(value, str) and value:
try:
email_info = validate_email(value, check_deliverability=False)
except EmailNotValidError:
continue
normalized_email = email_info.normalized or email_info.email
return normalized_email.lower()
return None
async def _sync_jwt_oidc_expiry(
user_manager: UserManager, user: User, payload: dict[str, Any]
) -> None:
if TRACK_EXTERNAL_IDP_EXPIRY:
expires_at = payload.get("exp")
if expires_at is None:
return
try:
expiry_timestamp = int(expires_at)
except (TypeError, ValueError):
logger.warning("Invalid exp claim on JWT for user %s", user.email)
return
oidc_expiry = datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc)
if user.oidc_expiry == oidc_expiry:
return
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
user.oidc_expiry = oidc_expiry # type: ignore
return
if user.oidc_expiry is not None:
await user_manager.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
async def _get_or_create_user_from_jwt(
payload: dict[str, Any],
request: Request,
async_db_session: AsyncSession,
) -> User | None:
email = _extract_email_from_jwt(payload)
if email is None:
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
return None
# Enforce the same allowlist/domain policies as other auth flows
verify_email_is_invited(email)
verify_email_domain(email)
user_db: SQLAlchemyUserAdminDB[User, uuid.UUID] = SQLAlchemyUserAdminDB(
async_db_session, User, OAuthAccount
)
user_manager = UserManager(user_db)
try:
user = await user_manager.get_by_email(email)
if not user.is_active:
logger.warning("Inactive user %s attempted JWT login; skipping", email)
return None
if not user.role.is_web_login():
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
logger.info("Provisioning user %s from JWT login", email)
try:
user = await user_manager.create(
UserCreate(
email=email,
password=generate_password(),
is_verified=True,
),
request=request,
)
except exceptions.UserAlreadyExists:
user = await user_manager.get_by_email(email)
if not user.is_active:
logger.warning(
"Inactive user %s attempted JWT login during provisioning race; skipping",
email,
)
return None
if not user.role.is_web_login():
logger.warning(
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
email,
)
return None
await _sync_jwt_oidc_expiry(user_manager, user, payload)
return user
async def _check_for_saml_and_jwt(
request: Request,
user: User | None,
@@ -1174,11 +1073,7 @@ async def _check_for_saml_and_jwt(
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
payload = await verify_jwt_token(token)
if payload is not None:
user = await _get_or_create_user_from_jwt(
payload, request, async_db_session
)
user = await verify_jwt_token(token, async_db_session)
return user
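
For comparison with the removed _extract_email_from_jwt helper, a sketch of the claim precedence it applied, assuming email-validator is installed (the payload is hypothetical; the verify_jwt_token path above only reads the "email" claim):

from email_validator import validate_email

payload = {"preferred_username": "Alice@Example.com"}  # hypothetical decoded JWT
for key in ("email", "preferred_username", "upn"):
    value = payload.get(key)
    if isinstance(value, str) and value:
        # Prints the normalized, lowercased address: alice@example.com
        print(validate_email(value, check_deliverability=False).normalized.lower())
        break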


@@ -53,8 +53,8 @@ from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import filter_tools_for_force_tool_use
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
@@ -109,10 +109,6 @@ def _run_agent_loop(
available_tools: Sequence[Tool] = (
dependencies.tools if iteration_count < MAX_ITERATIONS else []
)
if force_use_tool and force_use_tool.force_use:
available_tools = filter_tools_for_force_tool_use(
list(available_tools), force_use_tool
)
memories = get_memories(dependencies.user_or_none, dependencies.db_session)
# TODO: The system is rather prompt-cache efficient except for rebuilding the system prompt.
# The biggest offender is when we hit max iterations and then all the tool calls cannot
@@ -146,8 +142,10 @@ def _run_agent_loop(
tool_choice = None
else:
tool_choice = (
"required" if force_use_tool and force_use_tool.force_use else "auto"
)
force_use_tool_to_function_tool_names(force_use_tool, available_tools)
if iteration_count == 0 and force_use_tool
else None
) or "auto"
model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)
agent = Agent(


@@ -89,6 +89,9 @@ STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)


@@ -426,7 +426,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
title = ticket.properties.get("subject") or f"Ticket {ticket.id}"
link = self._get_object_url("tickets", ticket.id)
content_text = ticket.properties.get("content") or ""
content_text = ticket.properties.get("content", "")
# Main ticket section
sections = [TextSection(link=link, text=content_text)]
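
The difference between the two forms above: dict.get with a default does not guard against an explicitly stored None, while the or-"" form does. A quick illustration:

props = {"content": None}   # hypothetical ticket properties
props.get("content", "")    # -> None (key exists, value is None)
props.get("content") or ""  # -> ""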


@@ -15,14 +15,10 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
from onyx.context.search.federated.slack_search_utils import build_slack_queries
from onyx.context.search.federated.slack_search_utils import ChannelTypeString
from onyx.context.search.federated.slack_search_utils import get_channel_type
from onyx.context.search.federated.slack_search_utils import (
get_channel_type_for_missing_scope,
)
from onyx.context.search.federated.slack_search_utils import is_recency_query
from onyx.context.search.federated.slack_search_utils import should_include_message
from onyx.context.search.models import InferenceChunk
@@ -50,6 +46,7 @@ logger = setup_logger()
HIGHLIGHT_START_CHAR = "\ue000"
HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_TYPES = ["public_channel", "im", "mpim", "private_channel"]
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
@@ -101,12 +98,10 @@ def fetch_and_cache_channel_metadata(
# Retry logic with exponential backoff
last_exception = None
available_channel_types = ALL_CHANNEL_TYPES.copy()
for attempt in range(CHANNEL_METADATA_MAX_RETRIES):
try:
# Use available channel types (may be reduced if scopes are missing)
channel_types = ",".join(available_channel_types)
# ALWAYS fetch all channel types including private
channel_types = ",".join(CHANNEL_TYPES)
# Fetch all channels in one call
cursor = None
@@ -162,42 +157,6 @@ def fetch_and_cache_channel_metadata(
except SlackApiError as e:
last_exception = e
# Extract all needed fields from response upfront
if e.response:
error_response = e.response.get("error", "")
needed_scope = e.response.get("needed", "")
else:
error_response = ""
needed_scope = ""
# Check if this is a missing_scope error
if error_response == "missing_scope":
# Get the channel type that requires this scope
missing_channel_type = get_channel_type_for_missing_scope(needed_scope)
if (
missing_channel_type
and missing_channel_type in available_channel_types
):
# Remove the problematic channel type and retry
available_channel_types.remove(missing_channel_type)
logger.warning(
f"Missing scope '{needed_scope}' for channel type '{missing_channel_type}'. "
f"Continuing with reduced channel types: {available_channel_types}"
)
# Don't count this as a retry attempt, just try again with fewer types
if available_channel_types: # Only continue if we have types left
continue
# Otherwise fall through to retry logic
else:
logger.error(
f"Missing scope '{needed_scope}' but could not map to channel type or already removed. "
f"Response: {e.response}"
)
# For other errors, use retry logic
if attempt < CHANNEL_METADATA_MAX_RETRIES - 1:
retry_delay = CHANNEL_METADATA_RETRY_DELAY * (2**attempt)
logger.warning(
@@ -210,15 +169,7 @@ def fetch_and_cache_channel_metadata(
f"Failed to fetch channel metadata after {CHANNEL_METADATA_MAX_RETRIES} attempts: {e}"
)
# If we have some channel metadata despite errors, return it with a warning
if channel_metadata:
logger.warning(
f"Returning partial channel metadata ({len(channel_metadata)} channels) despite errors. "
f"Last error: {last_exception}"
)
return channel_metadata
# If we exhausted all retries and have no data, raise the last exception
# If we exhausted all retries, raise the last exception
if last_exception:
raise SlackApiError(
f"Channel metadata fetching failed after {CHANNEL_METADATA_MAX_RETRIES} attempts",

View File

@@ -29,9 +29,6 @@ DAYS_PER_WEEK = 7
DAYS_PER_MONTH = 30
MAX_CONTENT_WORDS = 3
# Punctuation to strip from words during analysis
WORD_PUNCTUATION = ".,!?;:\"'#"
RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]
@@ -44,48 +41,6 @@ class ChannelTypeString(str, Enum):
PUBLIC_CHANNEL = "public_channel"
# All Slack channel types for fetching metadata
ALL_CHANNEL_TYPES = [
ChannelTypeString.PUBLIC_CHANNEL.value,
ChannelTypeString.IM.value,
ChannelTypeString.MPIM.value,
ChannelTypeString.PRIVATE_CHANNEL.value,
]
# Map Slack API scopes to their corresponding channel types
# This is used for graceful degradation when scopes are missing
SCOPE_TO_CHANNEL_TYPE_MAP = {
"mpim:read": ChannelTypeString.MPIM.value,
"mpim:history": ChannelTypeString.MPIM.value,
"im:read": ChannelTypeString.IM.value,
"im:history": ChannelTypeString.IM.value,
"groups:read": ChannelTypeString.PRIVATE_CHANNEL.value,
"groups:history": ChannelTypeString.PRIVATE_CHANNEL.value,
"channels:read": ChannelTypeString.PUBLIC_CHANNEL.value,
"channels:history": ChannelTypeString.PUBLIC_CHANNEL.value,
}
def get_channel_type_for_missing_scope(scope: str) -> str | None:
"""Get the channel type that requires a specific Slack scope.
Args:
scope: The Slack API scope (e.g., 'mpim:read', 'im:history')
Returns:
The channel type string if scope is recognized, None otherwise
Examples:
>>> get_channel_type_for_missing_scope('mpim:read')
'mpim'
>>> get_channel_type_for_missing_scope('im:read')
'im'
>>> get_channel_type_for_missing_scope('unknown:scope')
None
"""
return SCOPE_TO_CHANNEL_TYPE_MAP.get(scope)
def _parse_llm_code_block_response(response: str) -> str:
"""Remove code block markers from LLM response if present.
@@ -109,40 +64,11 @@ def _parse_llm_code_block_response(response: str) -> str:
def is_recency_query(query: str) -> bool:
"""Check if a query is primarily about recency (not content + recency).
Returns True only for pure recency queries like "recent messages" or "latest updates",
but False for queries with content + recency like "golf scores last saturday".
"""
# Check if query contains recency keywords
has_recency_keyword = any(
return any(
re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE)
for keyword in RECENCY_KEYWORDS
)
if not has_recency_keyword:
return False
# Get combined stop words (NLTK + Slack-specific)
all_stop_words = _get_combined_stop_words()
# Extract content words (excluding stop words)
query_lower = query.lower()
words = query_lower.split()
# Count content words (not stop words, length > 2)
content_word_count = 0
for word in words:
clean_word = word.strip(WORD_PUNCTUATION)
if clean_word and len(clean_word) > 2 and clean_word not in all_stop_words:
content_word_count += 1
# If query has significant content words (>= 2), it's not a pure recency query
# Examples:
# - "recent messages" -> content_word_count = 0 -> pure recency
# - "golf scores last saturday" -> content_word_count = 3 (golf, scores, saturday) -> not pure recency
return content_word_count < 2
def extract_date_range_from_query(
query: str,
@@ -157,21 +83,6 @@ def extract_date_range_from_query(
if re.search(r"\byesterday\b", query_lower):
return min(1, default_search_days)
# Handle "last [day of week]" - e.g., "last monday", "last saturday"
days_of_week = [
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
]
for day in days_of_week:
if re.search(rf"\b(?:last|this)\s+{day}\b", query_lower):
# Assume last occurrence of that day was within the past week
return min(DAYS_PER_WEEK, default_search_days)
match = re.search(r"\b(?:last|past)\s+(\d+)\s+days?\b", query_lower)
if match:
days = int(match.group(1))
@@ -209,40 +120,22 @@ def extract_date_range_from_query(
try:
data = json.loads(response_clean)
if not isinstance(data, dict):
logger.debug(
f"LLM date extraction returned non-dict response for query: "
f"'{query}', using default: {default_search_days} days"
)
return default_search_days
days_back = data.get("days_back")
if days_back is None:
logger.debug(
f"LLM date extraction returned null for query: '{query}', "
f"using default: {default_search_days} days"
f"LLM date extraction returned null for query: '{query}', using default: {default_search_days} days"
)
return default_search_days
if not isinstance(days_back, (int, float)):
logger.debug(
f"LLM date extraction returned non-numeric days_back for "
f"query: '{query}', using default: {default_search_days} days"
)
return default_search_days
except json.JSONDecodeError:
logger.debug(
f"Failed to parse LLM date extraction response for query: '{query}' "
f"(response: '{response_clean}'), "
f"using default: {default_search_days} days"
f"Failed to parse LLM date extraction response for query: '{query}', using default: {default_search_days} days"
)
return default_search_days
return min(int(days_back), default_search_days)
return min(days_back, default_search_days)
except Exception as e:
logger.warning(f"Error extracting date range with LLM for query '{query}': {e}")
logger.warning(f"Error extracting date range with LLM: {e}")
return default_search_days
@@ -520,29 +413,6 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
)
def _get_combined_stop_words() -> set[str]:
"""Get combined NLTK + Slack-specific stop words.
Returns a set of stop words for filtering content words.
Falls back to just Slack-specific stop words if NLTK is unavailable.
Note: Currently only supports English stop words. Non-English queries
may have suboptimal content word extraction. Future enhancement could
detect query language and load appropriate stop words.
"""
try:
from nltk.corpus import stopwords # type: ignore
# TODO: Support multiple languages - currently hardcoded to English
# Could detect language or allow configuration
nltk_stop_words = set(stopwords.words("english"))
except Exception:
# Fallback if NLTK not available
nltk_stop_words = set()
return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
def extract_content_words_from_recency_query(
query_text: str, channel_references: set[str]
) -> list[str]:
@@ -557,19 +427,28 @@ def extract_content_words_from_recency_query(
Returns:
List of content words (up to MAX_CONTENT_WORDS)
"""
# Get combined stop words (NLTK + Slack-specific)
all_stop_words = _get_combined_stop_words()
# Get standard English stop words from NLTK (lazy import)
try:
from nltk.corpus import stopwords # type: ignore
nltk_stop_words = set(stopwords.words("english"))
except Exception:
# Fallback if NLTK not available
nltk_stop_words = set()
# Combine NLTK stop words with Slack-specific stop words
all_stop_words = nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
words = query_text.split()
content_words = []
for word in words:
clean_word = word.lower().strip(WORD_PUNCTUATION)
clean_word = word.lower().strip(".,!?;:\"'#")
# Skip if it's a channel reference or a stop word
if clean_word in channel_references:
continue
if clean_word and clean_word not in all_stop_words and len(clean_word) > 2:
clean_word_orig = word.strip(WORD_PUNCTUATION)
clean_word_orig = word.strip(".,!?;:\"'#")
if clean_word_orig.lower() not in all_stop_words:
content_words.append(clean_word_orig)
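
The word-boundary keyword check that both versions of is_recency_query share can be exercised in isolation; a minimal sketch with the keyword list copied from the constants above:

import re

RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]

def has_recency_keyword(query: str) -> bool:
    return any(
        re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE)
        for keyword in RECENCY_KEYWORDS
    )

assert has_recency_keyword("latest updates in #general")
assert has_recency_keyword("golf scores last saturday")
assert not has_recency_keyword("golf scores")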

View File

@@ -442,25 +442,10 @@ def set_cc_pair_repeated_error_state(
cc_pair_id: int,
in_repeated_error_state: bool,
) -> None:
values: dict = {"in_repeated_error_state": in_repeated_error_state}
# When entering repeated error state, also pause the connector
# to prevent continued indexing retry attempts.
# However, don't pause if there's an active manual indexing trigger,
# which indicates the user wants to retry immediately.
if in_repeated_error_state:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# Only pause if there's no manual indexing trigger active
if cc_pair and cc_pair.indexing_trigger is None:
values["status"] = ConnectorCredentialPairStatus.PAUSED
stmt = (
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.values(**values)
.values(in_repeated_error_state=in_repeated_error_state)
)
db_session.execute(stmt)
db_session.commit()
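
Both the old and new bodies of set_cc_pair_repeated_error_state issue a single Core-style UPDATE rather than loading the row first. A self-contained sketch of that update().where().values() pattern against a throwaway in-memory model (not the real ConnectorCredentialPair):

from sqlalchemy import Boolean, Integer, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Pair(Base):
    __tablename__ = "pair"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    in_repeated_error_state: Mapped[bool] = mapped_column(Boolean, default=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Pair(id=1))
    session.commit()
    # One UPDATE statement; no need to fetch the ORM object beforehand.
    session.execute(
        update(Pair).where(Pair.id == 1).values(in_repeated_error_state=True)
    )
    session.commit()
    print(session.get(Pair, 1).in_repeated_error_state)  # True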

View File

@@ -2482,50 +2482,6 @@ class CloudEmbeddingProvider(Base):
return f"<EmbeddingProvider(type='{self.provider_type}')>"
class InternetSearchProvider(Base):
__tablename__ = "internet_search_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self) -> str:
return f"<InternetSearchProvider(name='{self.name}', provider_type='{self.provider_type}')>"
class InternetContentProvider(Base):
__tablename__ = "internet_content_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self) -> str:
return f"<InternetContentProvider(name='{self.name}', provider_type='{self.provider_type}')>"
class DocumentSet(Base):
__tablename__ = "document_set"

View File

@@ -1,309 +0,0 @@
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import InternetContentProvider
from onyx.db.models import InternetSearchProvider
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
def fetch_web_search_providers(db_session: Session) -> list[InternetSearchProvider]:
stmt = select(InternetSearchProvider).order_by(InternetSearchProvider.id.asc())
return list(db_session.scalars(stmt).all())
def fetch_web_content_providers(db_session: Session) -> list[InternetContentProvider]:
stmt = select(InternetContentProvider).order_by(InternetContentProvider.id.asc())
return list(db_session.scalars(stmt).all())
def fetch_active_web_search_provider(
db_session: Session,
) -> InternetSearchProvider | None:
stmt = select(InternetSearchProvider).where(
InternetSearchProvider.is_active.is_(True)
)
return db_session.scalars(stmt).first()
def fetch_web_search_provider_by_id(
provider_id: int, db_session: Session
) -> InternetSearchProvider | None:
return db_session.get(InternetSearchProvider, provider_id)
def fetch_web_search_provider_by_name(
name: str, db_session: Session
) -> InternetSearchProvider | None:
stmt = select(InternetSearchProvider).where(InternetSearchProvider.name.ilike(name))
return db_session.scalars(stmt).first()
def _ensure_unique_search_name(
name: str, provider_id: int | None, db_session: Session
) -> None:
existing = fetch_web_search_provider_by_name(name=name, db_session=db_session)
if existing and existing.id != provider_id:
raise ValueError(f"A web search provider named '{name}' already exists.")
def _apply_search_provider_updates(
provider: InternetSearchProvider,
*,
name: str,
provider_type: WebSearchProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
) -> None:
provider.name = name
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
def upsert_web_search_provider(
*,
provider_id: int | None,
name: str,
provider_type: WebSearchProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
activate: bool,
db_session: Session,
) -> InternetSearchProvider:
_ensure_unique_search_name(
name=name, provider_id=provider_id, db_session=db_session
)
provider: InternetSearchProvider | None = None
if provider_id is not None:
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
else:
provider = InternetSearchProvider()
db_session.add(provider)
_apply_search_provider_updates(
provider,
name=name,
provider_type=provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
config=config,
)
db_session.flush()
if activate:
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider
def set_active_web_search_provider(
*, provider_id: int | None, db_session: Session
) -> InternetSearchProvider:
if provider_id is None:
raise ValueError("Cannot activate a provider without an id.")
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
db_session.execute(
update(InternetSearchProvider)
.where(
InternetSearchProvider.is_active.is_(True),
InternetSearchProvider.id != provider_id,
)
.values(is_active=False)
)
provider.is_active = True
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_web_search_provider(
*, provider_id: int | None, db_session: Session
) -> InternetSearchProvider:
if provider_id is None:
raise ValueError("Cannot deactivate a provider without an id.")
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
provider.is_active = False
db_session.flush()
db_session.refresh(provider)
return provider
def delete_web_search_provider(provider_id: int, db_session: Session) -> None:
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
db_session.delete(provider)
db_session.flush()
db_session.commit()
# Content provider helpers
def fetch_active_web_content_provider(
db_session: Session,
) -> InternetContentProvider | None:
stmt = select(InternetContentProvider).where(
InternetContentProvider.is_active.is_(True)
)
return db_session.scalars(stmt).first()
def fetch_web_content_provider_by_id(
provider_id: int, db_session: Session
) -> InternetContentProvider | None:
return db_session.get(InternetContentProvider, provider_id)
def fetch_web_content_provider_by_name(
name: str, db_session: Session
) -> InternetContentProvider | None:
stmt = select(InternetContentProvider).where(
InternetContentProvider.name.ilike(name)
)
return db_session.scalars(stmt).first()
def _ensure_unique_content_name(
name: str, provider_id: int | None, db_session: Session
) -> None:
existing = fetch_web_content_provider_by_name(name=name, db_session=db_session)
if existing and existing.id != provider_id:
raise ValueError(f"A web content provider named '{name}' already exists.")
def _apply_content_provider_updates(
provider: InternetContentProvider,
*,
name: str,
provider_type: WebContentProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
) -> None:
provider.name = name
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
def upsert_web_content_provider(
*,
provider_id: int | None,
name: str,
provider_type: WebContentProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
activate: bool,
db_session: Session,
) -> InternetContentProvider:
_ensure_unique_content_name(
name=name, provider_id=provider_id, db_session=db_session
)
provider: InternetContentProvider | None = None
if provider_id is not None:
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
else:
provider = InternetContentProvider()
db_session.add(provider)
_apply_content_provider_updates(
provider,
name=name,
provider_type=provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
config=config,
)
db_session.flush()
if activate:
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider
def set_active_web_content_provider(
*, provider_id: int | None, db_session: Session
) -> InternetContentProvider:
if provider_id is None:
raise ValueError("Cannot activate a provider without an id.")
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
db_session.execute(
update(InternetContentProvider)
.where(
InternetContentProvider.is_active.is_(True),
InternetContentProvider.id != provider_id,
)
.values(is_active=False)
)
provider.is_active = True
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_web_content_provider(
*, provider_id: int | None, db_session: Session
) -> InternetContentProvider:
if provider_id is None:
raise ValueError("Cannot deactivate a provider without an id.")
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
provider.is_active = False
db_session.flush()
db_session.refresh(provider)
return provider
def delete_web_content_provider(provider_id: int, db_session: Session) -> None:
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
db_session.delete(provider)
db_session.flush()
db_session.commit()
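
The activate helpers in the module above all follow the same "exactly one active row" recipe: one UPDATE flips every other active row to inactive, then the chosen row is marked active. A hedged sketch of that recipe with the mapped model passed in as a stand-in:

from sqlalchemy import update
from sqlalchemy.orm import Session

def set_single_active(db_session: Session, model, row_id: int) -> None:
    # Deactivate everything else in one statement, then activate the chosen row.
    db_session.execute(
        update(model)
        .where(model.is_active.is_(True), model.id != row_id)
        .values(is_active=False)
    )
    row = db_session.get(model, row_id)
    if row is None:
        raise ValueError(f"No row with id {row_id} exists.")
    row.is_active = True
    db_session.flush()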

View File

@@ -510,7 +510,6 @@ class LitellmLLM(LLM):
# model params
temperature=(1 if is_reasoning else self._temperature),
timeout=timeout_override or self._timeout,
**({"stream_options": {"include_usage": True}} if stream else {}),
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
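
For context on the stream_options line in the hunk above: with OpenAI-compatible chat APIs, stream_options={"include_usage": True} requests one extra streamed chunk that carries token usage, and LiteLLM passes the option through. A hedged usage sketch; the model name and the attribute access on the chunk are assumptions about your setup:

import litellm

response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in response:
    usage = getattr(chunk, "usage", None)
    if usage:
        print(usage)  # populated only on the final usage-bearing chunk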

View File

@@ -3,87 +3,29 @@ import time
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union
from litellm import AllMessageValues
from litellm.completion_extras.litellm_responses_transformation.transformation import (
OpenAiResponsesToChatCompletionStreamIterator,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
try:
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
except ImportError:
extract_images_from_message = None # type: ignore[assignment]
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
from litellm.llms.ollama.chat.transformation import OllamaChatConfig
from litellm.llms.ollama.common_utils import OllamaError
try:
from litellm.types.llms.ollama import OllamaChatCompletionMessage
except ImportError:
class OllamaChatCompletionMessage(TypedDict, total=False): # type: ignore[no-redef]
"""Fallback for LiteLLM versions where this TypedDict was removed."""
role: str
content: Optional[str]
images: Optional[List[Any]]
thinking: Optional[str]
tool_calls: Optional[List["OllamaToolCall"]]
from litellm.types.llms.ollama import OllamaChatCompletionMessage
from litellm.types.llms.ollama import OllamaToolCall
from litellm.types.llms.ollama import OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import ChatCompletionUsageBlock
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import ModelResponseStream
from litellm.utils import verbose_logger
from pydantic import BaseModel
if extract_images_from_message is None:
def extract_images_from_message(
message: AllMessageValues,
) -> Optional[List[Any]]:
"""Fallback for LiteLLM versions that dropped extract_images_from_message."""
images: List[Any] = []
content = message.get("content")
if not isinstance(content, list):
return None
for item in content:
if not isinstance(item, Dict):
continue
item_type = item.get("type")
if item_type == "image_url":
image_url = item.get("image_url")
if isinstance(image_url, dict):
if image_url.get("url"):
images.append(image_url)
elif image_url:
images.append(image_url)
elif item_type in {"input_image", "image"}:
image_value = item.get("image")
if image_value:
images.append(image_value)
return images or None
def _patch_ollama_transform_request() -> None:
"""
Patches OllamaChatConfig.transform_request to handle reasoning content
@@ -312,189 +254,16 @@ def _patch_ollama_chunk_parser() -> None:
OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser # type: ignore[method-assign]
def _patch_openai_responses_chunk_parser() -> None:
"""
Patches OpenAiResponsesToChatCompletionStreamIterator.chunk_parser to properly
handle OpenAI Responses API streaming format and convert it to chat completion format.
"""
if (
getattr(
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser,
"__name__",
"",
)
== "_patched_openai_responses_chunk_parser"
):
return
def _patched_openai_responses_chunk_parser(
self: Any, chunk: dict
) -> Union["GenericStreamingChunk", "ModelResponseStream"]:
# Transform responses API streaming chunk to chat completion format
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import (
ChatCompletionToolCallChunk,
GenericStreamingChunk,
)
verbose_logger.debug(
f"Chat provider: transform_streaming_response called with chunk: {chunk}"
)
parsed_chunk = chunk
if not parsed_chunk:
raise ValueError("Chat provider: Empty parsed_chunk")
if not isinstance(parsed_chunk, dict):
raise ValueError(f"Chat provider: Invalid chunk type {type(parsed_chunk)}")
# Handle different event types from responses API
event_type = parsed_chunk.get("type")
verbose_logger.debug(f"Chat provider: Processing event type: {event_type}")
if event_type == "response.created":
# Initial response creation event
verbose_logger.debug(f"Chat provider: response.created -> {chunk}")
return GenericStreamingChunk(
text="", tool_use=None, is_finished=False, finish_reason="", usage=None
)
elif event_type == "response.output_item.added":
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=output_item.get("call_id"),
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=output_item.get("name", None),
arguments=parsed_chunk.get("arguments", ""),
),
),
is_finished=False,
finish_reason="",
usage=None,
)
elif output_item.get("type") == "message":
pass
elif output_item.get("type") == "reasoning":
pass
else:
raise ValueError(f"Chat provider: Invalid output_item {output_item}")
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=None,
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=None, arguments=content_part
),
),
is_finished=False,
finish_reason="",
usage=None,
)
else:
raise ValueError(
f"Chat provider: Invalid function argument delta {parsed_chunk}"
)
elif event_type == "response.output_item.done":
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=output_item.get("call_id"),
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=parsed_chunk.get("name", None),
arguments="", # responses API sends everything again, we don't
),
),
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
elif output_item.get("type") == "message":
return GenericStreamingChunk(
finish_reason="stop", is_finished=True, usage=None, text=""
)
elif output_item.get("type") == "reasoning":
pass
else:
raise ValueError(f"Chat provider: Invalid output_item {output_item}")
elif event_type == "response.output_text.delta":
# Content part added to output
content_part = parsed_chunk.get("delta", None)
if content_part is not None:
return GenericStreamingChunk(
text=content_part,
tool_use=None,
is_finished=False,
finish_reason="",
usage=None,
)
else:
raise ValueError(f"Chat provider: Invalid text delta {parsed_chunk}")
elif event_type == "response.reasoning_summary_text.delta":
content_part = parsed_chunk.get("delta", None)
if content_part:
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=cast(int, parsed_chunk.get("summary_index")),
delta=Delta(reasoning_content=content_part),
)
]
)
else:
pass
# For any unhandled event types, create a minimal valid chunk or skip
verbose_logger.debug(
f"Chat provider: Unhandled event type '{event_type}', creating empty chunk"
)
# Return a minimal valid chunk for unknown events
return GenericStreamingChunk(
text="", tool_use=None, is_finished=False, finish_reason="", usage=None
)
_patched_openai_responses_chunk_parser.__name__ = (
"_patched_openai_responses_chunk_parser"
)
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_openai_responses_chunk_parser # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
Apply all necessary monkey patches to LiteLLM for Ollama compatibility.
This includes:
- Patching OllamaChatConfig.transform_request for reasoning content support
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
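
One version of the module above guards its LiteLLM-internal imports with try/except ImportError plus local fallbacks so the monkey patches survive library upgrades. A self-contained sketch of that guarded-import pattern; the module path below is deliberately fictitious:

try:
    from some_library.utils import extract_images  # symbol may not exist in newer releases
except ImportError:
    extract_images = None  # type: ignore[assignment]

if extract_images is None:
    def extract_images(message: dict) -> list | None:
        """Local fallback mirroring the behaviour the caller expects."""
        content = message.get("content")
        if not isinstance(content, list):
            return None
        images = [
            item for item in content
            if isinstance(item, dict) and item.get("type") == "image_url"
        ]
        return images or None

print(extract_images({"content": [{"type": "image_url", "image_url": {"url": "x"}}]}))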

View File

@@ -0,0 +1,146 @@
from collections.abc import Sequence
from langchain.schema.messages import BaseMessage
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ContentPart
from onyx.llm.message_types import ImageContentPart
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import TextContentPart
from onyx.llm.message_types import UserMessageWithParts
from onyx.llm.message_types import UserMessageWithText
def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
def _base_message_to_chat_completion_msg(msg: BaseMessage) -> ChatCompletionMessage:
message_type_to_role = {
"human": "user",
"system": "system",
"ai": "assistant",
"tool": "tool",
}
role = message_type_to_role[msg.type]
content = msg.content
if isinstance(content, str):
# Simple string content
if role == "system":
system_msg: SystemMessage = {
"role": "system",
"content": content,
}
return system_msg
elif role == "user":
user_msg: UserMessageWithText = {
"role": "user",
"content": content,
}
return user_msg
else: # assistant
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
return assistant_msg
elif isinstance(content, list):
# List content - need to convert to OpenAI format
if role == "assistant":
# For assistant, convert list to simple string
# (OpenAI format uses string content, not list)
text_parts = []
for item in content:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
else:
raise ValueError(
f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
)
assistant_msg_from_list: AssistantMessage = {
"role": "assistant",
"content": " ".join(text_parts) if text_parts else None,
}
return assistant_msg_from_list
else: # system or user
content_parts: list[ContentPart] = []
has_images = False
for item in content:
if isinstance(item, str):
content_parts.append(TextContentPart(type="text", text=item))
elif isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
content_parts.append(
TextContentPart(type="text", text=item.get("text", ""))
)
elif item_type == "image_url":
has_images = True
# Convert image_url to OpenAI format
image_url = item.get("image_url", {})
if isinstance(image_url, dict):
url = image_url.get("url", "")
detail = image_url.get("detail", "auto")
else:
url = image_url
detail = "auto"
image_part: ImageContentPart = {
"type": "image_url",
"image_url": {"url": url, "detail": detail},
}
content_parts.append(image_part)
else:
raise ValueError(f"Unexpected item type: {item_type}")
else:
raise ValueError(
f"Unexpected item type: {type(item)}. Item: {item}"
)
if role == "system":
# System messages should be text only, concatenate all text parts
text_parts = [
part["text"] for part in content_parts if part["type"] == "text"
]
system_msg_from_list: SystemMessage = {
"role": "system",
"content": " ".join(text_parts),
}
return system_msg_from_list
else: # user
# If there are images, use the parts format; otherwise use simple string
if has_images or len(content_parts) > 1:
user_msg_with_parts: UserMessageWithParts = {
"role": "user",
"content": content_parts,
}
return user_msg_with_parts
elif len(content_parts) == 1 and content_parts[0]["type"] == "text":
# Single text part - use simple string format
user_msg_simple: UserMessageWithText = {
"role": "user",
"content": content_parts[0]["text"],
}
return user_msg_simple
else:
# Empty content
user_msg_empty: UserMessageWithText = {
"role": "user",
"content": "",
}
return user_msg_empty
else:
raise ValueError(
f"Unexpected content type: {type(content)}. Content: {content}"
)
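
A hedged usage sketch for the converter added above. base_messages_to_chat_completion_msgs is the function defined in this new file; its import path is not shown in the diff, so it is left as a commented placeholder, and the expected output is paraphrased from the code:

from langchain.schema.messages import HumanMessage, SystemMessage

# from <the new module above> import base_messages_to_chat_completion_msgs

msgs = [
    SystemMessage(content="You are terse."),
    HumanMessage(content=[
        {"type": "text", "text": "Describe this image"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]),
]
converted = base_messages_to_chat_completion_msgs(msgs)
# -> [{"role": "system", "content": "You are terse."},
#     {"role": "user", "content": [<text part>, <image_url part with detail "auto">]}]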

View File

@@ -38,19 +38,10 @@ class StreamingChoice(BaseModel):
delta: Delta = Field(default_factory=Delta)
class Usage(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
cache_creation_input_tokens: int
cache_read_input_tokens: int
class ModelResponseStream(BaseModel):
id: str
created: str
choice: StreamingChoice
usage: Usage | None = None
if TYPE_CHECKING:
@@ -183,24 +174,10 @@ def from_litellm_model_response_stream(
delta=parsed_delta,
)
usage_data = response_data.get("usage")
return ModelResponseStream(
id=response_id,
created=created,
choice=streaming_choice,
usage=(
Usage(
completion_tokens=usage_data.get("completion_tokens", 0),
prompt_tokens=usage_data.get("prompt_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
cache_creation_input_tokens=usage_data.get(
"cache_creation_input_tokens", 0
),
cache_read_input_tokens=usage_data.get("cache_read_input_tokens", 0),
)
if usage_data
else None
),
)
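
The usage-handling hunk above maps the provider's raw usage dict into a typed pydantic model with .get(..., 0) defaults. A self-contained sketch of that mapping with a stand-in model (field defaults here are mine, not necessarily the real model's):

from pydantic import BaseModel

class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0

raw = {"completion_tokens": 12, "prompt_tokens": 40, "total_tokens": 52}
usage = Usage(
    completion_tokens=raw.get("completion_tokens", 0),
    prompt_tokens=raw.get("prompt_tokens", 0),
    total_tokens=raw.get("total_tokens", 0),
    cache_creation_input_tokens=raw.get("cache_creation_input_tokens", 0),
    cache_read_input_tokens=raw.get("cache_read_input_tokens", 0),
)
print(usage.total_tokens)  # 52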

View File

@@ -116,7 +116,6 @@ def litellm_exception_to_error_msg(
from litellm.exceptions import Timeout
from litellm.exceptions import ContentPolicyViolationError
from litellm.exceptions import BudgetExceededError
from litellm.exceptions import ServiceUnavailableError
core_exception = _unwrap_nested_exception(e)
error_msg = str(core_exception)
@@ -169,23 +168,6 @@ def litellm_exception_to_error_msg(
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
elif isinstance(core_exception, ServiceUnavailableError):
provider_name = (
llm.config.model_provider
if llm is not None and llm.config.model_provider
else "The LLM provider"
)
# Check if this is specifically the Bedrock "Too many connections" error
if "Too many connections" in error_msg or "BedrockException" in error_msg:
error_msg = (
f"{provider_name} is experiencing high connection volume and cannot process your request right now. "
"This typically happens when there are too many simultaneous requests to the AI model. "
"Please wait a moment and try again. If this persists, contact your system administrator "
"to review connection limits and retry configurations."
)
else:
# Generic 503 Service Unavailable
error_msg = f"{provider_name} service error: {str(core_exception)}"
elif isinstance(core_exception, ContextWindowExceededError):
error_msg = (
"Context window exceeded: Your input is too long for the model to process."

View File

@@ -1,7 +1,6 @@
import logging
import sys
import traceback
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
@@ -103,9 +102,6 @@ from onyx.server.manage.llm.api import basic_router as llm_router
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -147,13 +143,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
warnings.filterwarnings(
"ignore", category=ResourceWarning, message=r"Unclosed client session"
)
warnings.filterwarnings(
"ignore", category=ResourceWarning, message=r"Unclosed connector"
)
logger = setup_logger()
file_handlers = [
@@ -403,7 +392,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
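
For reference, warnings.filterwarnings matches its message argument as a regular expression against the start of the warning text, so the two filters in the hunk above suppress only the "Unclosed ..." ResourceWarnings (typically emitted by aiohttp) and leave other warnings visible. A minimal standalone check:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.filterwarnings(
        "ignore", category=ResourceWarning, message=r"Unclosed client session"
    )
    warnings.warn("Unclosed client session", ResourceWarning)
    warnings.warn("Something else entirely", ResourceWarning)

print(len(caught))  # 1 -- only the warning that does not match the filter is recorded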

View File

@@ -125,18 +125,18 @@ def get_versions() -> AllVersions:
onyx=latest_stable_version,
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.25.5-alpine",
nginx="nginx:1.23.4-alpine",
),
dev=ContainerVersions(
onyx=latest_dev_version,
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.25.5-alpine",
nginx="nginx:1.23.4-alpine",
),
migration=ContainerVersions(
onyx="airgapped-intfloat-nomic-migration",
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.25.5-alpine",
nginx="nginx:1.23.4-alpine",
),
)

View File

@@ -1,364 +0,0 @@
from __future__ import annotations
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_content_provider_from_config,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_search_provider_from_config,
)
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.web_search import deactivate_web_content_provider
from onyx.db.web_search import deactivate_web_search_provider
from onyx.db.web_search import delete_web_content_provider
from onyx.db.web_search import delete_web_search_provider
from onyx.db.web_search import fetch_web_content_provider_by_name
from onyx.db.web_search import fetch_web_content_providers
from onyx.db.web_search import fetch_web_search_provider_by_name
from onyx.db.web_search import fetch_web_search_providers
from onyx.db.web_search import set_active_web_content_provider
from onyx.db.web_search import set_active_web_search_provider
from onyx.db.web_search import upsert_web_content_provider
from onyx.db.web_search import upsert_web_search_provider
from onyx.server.manage.web_search.models import WebContentProviderTestRequest
from onyx.server.manage.web_search.models import WebContentProviderUpsertRequest
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderTestRequest
from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/web-search")
@admin_router.get("/search-providers", response_model=list[WebSearchProviderView])
def list_search_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[WebSearchProviderView]:
providers = fetch_web_search_providers(db_session)
return [
WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
for provider in providers
]
@admin_router.post("/search-providers", response_model=WebSearchProviderView)
def upsert_search_provider_endpoint(
request: WebSearchProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
existing_by_name = fetch_web_search_provider_by_name(request.name, db_session)
if (
existing_by_name
and request.id is not None
and existing_by_name.id != request.id
):
raise HTTPException(
status_code=400,
detail=f"A search provider named '{request.name}' already exists.",
)
provider = upsert_web_search_provider(
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=request.api_key,
api_key_changed=request.api_key_changed,
config=request.config,
activate=request.activate,
db_session=db_session,
)
db_session.commit()
return WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.delete(
"/search-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
delete_web_search_provider(provider_id, db_session)
return Response(status_code=204)
@admin_router.post("/search-providers/{provider_id}/activate")
def activate_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
provider = set_active_web_search_provider(
provider_id=provider_id, db_session=db_session
)
db_session.commit()
return WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.post("/search-providers/{provider_id}/deactivate")
def deactivate_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
deactivate_web_search_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/search-providers/test")
def test_search_provider(
request: WebSearchProviderTestRequest,
_: User | None = Depends(current_admin_user),
) -> dict[str, str]:
try:
provider = build_search_provider_from_config(
provider_type=request.provider_type,
api_key=request.api_key,
config=request.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400, detail="Unable to build provider configuration."
)
# Actually test the API key by making a real search call
try:
test_results = provider.search("test")
if not test_results or not any(result.link for result in test_results):
raise HTTPException(
status_code=400,
detail="API key validation failed: search returned no results.",
)
except HTTPException:
raise
except Exception as e:
error_msg = str(e)
if (
"api" in error_msg.lower()
or "key" in error_msg.lower()
or "auth" in error_msg.lower()
):
raise HTTPException(
status_code=400,
detail=f"Invalid API key: {error_msg}",
) from e
raise HTTPException(
status_code=400,
detail=f"API key validation failed: {error_msg}",
) from e
logger.info(
f"Web search provider test succeeded for {request.provider_type.value}."
)
return {"status": "ok"}
@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
def list_content_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[WebContentProviderView]:
providers = fetch_web_content_providers(db_session)
return [
WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
for provider in providers
]
@admin_router.post("/content-providers", response_model=WebContentProviderView)
def upsert_content_provider_endpoint(
request: WebContentProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebContentProviderView:
existing_by_name = fetch_web_content_provider_by_name(request.name, db_session)
if (
existing_by_name
and request.id is not None
and existing_by_name.id != request.id
):
raise HTTPException(
status_code=400,
detail=f"A content provider named '{request.name}' already exists.",
)
provider = upsert_web_content_provider(
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=request.api_key,
api_key_changed=request.api_key_changed,
config=request.config,
activate=request.activate,
db_session=db_session,
)
db_session.commit()
return WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.delete(
"/content-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
delete_web_content_provider(provider_id, db_session)
return Response(status_code=204)
@admin_router.post("/content-providers/{provider_id}/activate")
def activate_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebContentProviderView:
provider = set_active_web_content_provider(
provider_id=provider_id, db_session=db_session
)
db_session.commit()
return WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.post("/content-providers/reset-default")
def reset_content_provider_default(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
providers = fetch_web_content_providers(db_session)
active_ids = [provider.id for provider in providers if provider.is_active]
for provider_id in active_ids:
deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/content-providers/{provider_id}/deactivate")
def deactivate_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/content-providers/test")
def test_content_provider(
request: WebContentProviderTestRequest,
_: User | None = Depends(current_admin_user),
) -> dict[str, str]:
try:
provider = build_content_provider_from_config(
provider_type=request.provider_type,
api_key=request.api_key,
config=request.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400, detail="Unable to build provider configuration."
)
# Actually test the API key by making a real content fetch call
try:
test_url = "https://example.com"
test_results = provider.contents([test_url])
if not test_results or not any(
result.scrape_successful for result in test_results
):
raise HTTPException(
status_code=400,
detail="API key validation failed: content fetch returned no results.",
)
except HTTPException:
raise
except Exception as e:
error_msg = str(e)
if (
"api" in error_msg.lower()
or "key" in error_msg.lower()
or "auth" in error_msg.lower()
):
raise HTTPException(
status_code=400,
detail=f"Invalid API key: {error_msg}",
) from e
raise HTTPException(
status_code=400,
detail=f"API key validation failed: {error_msg}",
) from e
logger.info(
f"Web content provider test succeeded for {request.provider_type.value}."
)
return {"status": "ok"}

View File

@@ -1,69 +0,0 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
class WebSearchProviderView(BaseModel):
id: int
name: str
provider_type: WebSearchProviderType
is_active: bool
config: dict[str, str] | None
has_api_key: bool = Field(
default=False,
description="Indicates whether an API key is stored for this provider.",
)
class WebSearchProviderUpsertRequest(BaseModel):
id: int | None = Field(default=None, description="Existing provider ID to update.")
name: str
provider_type: WebSearchProviderType
config: dict[str, str] | None = None
api_key: str | None = Field(
default=None,
description="API key for the provider. Only required when creating or updating credentials.",
)
api_key_changed: bool = Field(
default=False,
description="Set to true when providing a new API key for an existing provider.",
)
activate: bool = Field(
default=False,
description="If true, sets this provider as the active one after upsert.",
)
class WebContentProviderView(BaseModel):
id: int
name: str
provider_type: WebContentProviderType
is_active: bool
config: dict[str, str] | None
has_api_key: bool = Field(default=False)
class WebContentProviderUpsertRequest(BaseModel):
id: int | None = None
name: str
provider_type: WebContentProviderType
config: dict[str, str] | None = None
api_key: str | None = None
api_key_changed: bool = False
activate: bool = False
class WebSearchProviderTestRequest(BaseModel):
provider_type: WebSearchProviderType
api_key: str | None = None
config: dict[str, Any] | None = None
class WebContentProviderTestRequest(BaseModel):
provider_type: WebContentProviderType
api_key: str | None = None
config: dict[str, Any] | None = None

View File

@@ -16,7 +16,7 @@ class ForceUseTool(BaseModel):
def build_openai_tool_choice_dict(self) -> dict[str, Any]:
"""Build dict in the format that OpenAI expects which tells them to use this tool."""
return {"type": "function", "name": self.tool_name}
return {"type": "function", "function": {"name": self.tool_name}}
def filter_tools_for_force_tool_use(
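
The one-line change to build_openai_tool_choice_dict above switches between the two tool_choice shapes OpenAI accepts: the Chat Completions API nests the function name under "function", while the newer Responses API takes a flat object. A small sketch of both shapes (function names here are illustrative):

def chat_completions_tool_choice(tool_name: str) -> dict:
    # Chat Completions format: {"type": "function", "function": {"name": ...}}
    return {"type": "function", "function": {"name": tool_name}}

def responses_api_tool_choice(tool_name: str) -> dict:
    # Responses API format: {"type": "function", "name": ...}
    return {"type": "function", "name": tool_name}

print(chat_completions_tool_choice("run_web_search"))
print(responses_api_tool_choice("run_web_search"))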

View File

@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.message import ToolCallSummary
@@ -51,10 +51,8 @@ class WebSearchTool(Tool[None]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
"""Available only if an active web search provider is configured in the database."""
with get_session_with_current_tenant() as session:
provider = fetch_active_web_search_provider(session)
return provider is not None
"""Available only if EXA or SERPER API key is configured."""
return bool(EXA_API_KEY) or bool(SERPER_API_KEY)
def tool_definition(self) -> dict:
return {

View File

@@ -6,20 +6,14 @@ from pydantic import TypeAdapter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_content_provider,
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
@@ -196,7 +190,7 @@ changing or evolving.
def _open_url_core(
run_context: RunContextWrapper[ChatTurnContext],
urls: Sequence[str],
content_provider: WebContentProvider,
search_provider: WebSearchProvider,
) -> list[LlmOpenUrlResult]:
# TODO: Find better way to track index that isn't so implicit
# based on number of tool calls
@@ -212,7 +206,7 @@ def _open_url_core(
)
)
docs = content_provider.contents(urls)
docs = search_provider.contents(urls)
results = [
LlmOpenUrlResult(
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
@@ -273,10 +267,10 @@ def open_url(
"""
Tool for fetching and extracting full content from web pages.
"""
content_provider = get_default_content_provider()
if content_provider is None:
raise ValueError("No web content provider found")
retrieved_docs = _open_url_core(run_context, urls, content_provider)
search_provider = get_default_provider()
if search_provider is None:
raise ValueError("No search provider found")
retrieved_docs = _open_url_core(run_context, urls, search_provider)
adapter = TypeAdapter(list[LlmOpenUrlResult])
return adapter.dump_json(retrieved_docs).decode()

View File

@@ -80,11 +80,7 @@ def setup_braintrust_if_creds_available() -> None:
api_key=BRAINTRUST_API_KEY,
)
braintrust.set_masking_function(_mask)
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
_setup_legacy_langchain_tracing()
logger.notice("Braintrust tracing initialized")
def _setup_legacy_langchain_tracing() -> None:
handler = BraintrustCallbackHandler()
set_global_handler(handler)
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
logger.notice("Braintrust tracing initialized")

View File

@@ -1,216 +0,0 @@
import datetime
from typing import Any
from typing import Dict
from typing import Optional
import braintrust
from braintrust import NOOP_SPAN
from .framework.processor_interface import TracingProcessor
from .framework.span_data import AgentSpanData
from .framework.span_data import FunctionSpanData
from .framework.span_data import GenerationSpanData
from .framework.span_data import SpanData
from .framework.spans import Span
from .framework.traces import Trace
def _span_type(span: Span[Any]) -> braintrust.SpanTypeAttribute:
if span.span_data.type in ["agent"]:
return braintrust.SpanTypeAttribute.TASK
elif span.span_data.type in ["function"]:
return braintrust.SpanTypeAttribute.TOOL
elif span.span_data.type in ["generation"]:
return braintrust.SpanTypeAttribute.LLM
else:
return braintrust.SpanTypeAttribute.TASK
def _span_name(span: Span[Any]) -> str:
if isinstance(span.span_data, AgentSpanData) or isinstance(
span.span_data, FunctionSpanData
):
return span.span_data.name
elif isinstance(span.span_data, GenerationSpanData):
return "Generation"
else:
return "Unknown"
def _timestamp_from_maybe_iso(timestamp: Optional[str]) -> Optional[float]:
if timestamp is None:
return None
return datetime.datetime.fromisoformat(timestamp).timestamp()
def _maybe_timestamp_elapsed(
end: Optional[str], start: Optional[str]
) -> Optional[float]:
if start is None or end is None:
return None
return (
datetime.datetime.fromisoformat(end) - datetime.datetime.fromisoformat(start)
).total_seconds()
class BraintrustTracingProcessor(TracingProcessor):
"""
`BraintrustTracingProcessor` is a `tracing.TracingProcessor` that logs traces to Braintrust.
Args:
logger: A `braintrust.Span` or `braintrust.Experiment` or `braintrust.Logger` to use for logging.
If `None`, the current span, experiment, or logger will be selected exactly as in `braintrust.start_span`.
"""
def __init__(self, logger: Optional[braintrust.Logger] = None):
self._logger = logger
self._spans: Dict[str, Any] = {}
self._first_input: Dict[str, Any] = {}
self._last_output: Dict[str, Any] = {}
def on_trace_start(self, trace: Trace) -> None:
trace_meta = trace.export() or {}
metadata = {
"group_id": trace_meta.get("group_id"),
**(trace_meta.get("metadata") or {}),
}
current_context = braintrust.current_span()
if current_context != NOOP_SPAN:
self._spans[trace.trace_id] = current_context.start_span( # type: ignore[assignment]
name=trace.name,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,
)
elif self._logger is not None:
self._spans[trace.trace_id] = self._logger.start_span( # type: ignore[assignment]
span_attributes={"type": "task", "name": trace.name},
span_id=trace.trace_id,
root_span_id=trace.trace_id,
metadata=metadata,
)
else:
self._spans[trace.trace_id] = braintrust.start_span( # type: ignore[assignment]
id=trace.trace_id,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,
)
def on_trace_end(self, trace: Trace) -> None:
span: Any = self._spans.pop(trace.trace_id)
# Get the first input and last output for this specific trace
trace_first_input = self._first_input.pop(trace.trace_id, None)
trace_last_output = self._last_output.pop(trace.trace_id, None)
span.log(input=trace_first_input, output=trace_last_output)
span.end()
def _agent_log_data(self, span: Span[AgentSpanData]) -> Dict[str, Any]:
return {
"metadata": {
"tools": span.span_data.tools,
"handoffs": span.span_data.handoffs,
"output_type": span.span_data.output_type,
}
}
def _function_log_data(self, span: Span[FunctionSpanData]) -> Dict[str, Any]:
return {
"input": span.span_data.input,
"output": span.span_data.output,
}
def _generation_log_data(self, span: Span[GenerationSpanData]) -> Dict[str, Any]:
metrics = {}
ttft = _maybe_timestamp_elapsed(span.ended_at, span.started_at)
if ttft is not None:
metrics["time_to_first_token"] = ttft
usage = span.span_data.usage or {}
if "prompt_tokens" in usage:
metrics["prompt_tokens"] = usage["prompt_tokens"]
elif "input_tokens" in usage:
metrics["prompt_tokens"] = usage["input_tokens"]
if "completion_tokens" in usage:
metrics["completion_tokens"] = usage["completion_tokens"]
elif "output_tokens" in usage:
metrics["completion_tokens"] = usage["output_tokens"]
if "total_tokens" in usage:
metrics["tokens"] = usage["total_tokens"]
elif "input_tokens" in usage and "output_tokens" in usage:
metrics["tokens"] = usage["input_tokens"] + usage["output_tokens"]
if "cache_read_input_tokens" in usage:
metrics["prompt_cached_tokens"] = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
metrics["prompt_cache_creation_tokens"] = usage[
"cache_creation_input_tokens"
]
return {
"input": span.span_data.input,
"output": span.span_data.output,
"metadata": {
"model": span.span_data.model,
"model_config": span.span_data.model_config,
},
"metrics": metrics,
}
def _log_data(self, span: Span[Any]) -> Dict[str, Any]:
if isinstance(span.span_data, AgentSpanData):
return self._agent_log_data(span)
elif isinstance(span.span_data, FunctionSpanData):
return self._function_log_data(span)
elif isinstance(span.span_data, GenerationSpanData):
return self._generation_log_data(span)
else:
return {}
def on_span_start(self, span: Span[SpanData]) -> None:
parent: Any = (
self._spans[span.parent_id]
if span.parent_id is not None
else self._spans[span.trace_id]
)
created_span: Any = parent.start_span(
id=span.span_id,
name=_span_name(span),
type=_span_type(span),
start_time=_timestamp_from_maybe_iso(span.started_at),
)
self._spans[span.span_id] = created_span
# Set the span as current so current_span() calls will return it
created_span.set_current()
def on_span_end(self, span: Span[SpanData]) -> None:
s: Any = self._spans.pop(span.span_id)
event = dict(error=span.error, **self._log_data(span))
s.log(**event)
s.unset_current()
s.end(_timestamp_from_maybe_iso(span.ended_at))
input_ = event.get("input")
output = event.get("output")
# Store first input and last output per trace_id
trace_id = span.trace_id
if trace_id not in self._first_input and input_ is not None:
self._first_input[trace_id] = input_
if output is not None:
self._last_output[trace_id] = output
def shutdown(self) -> None:
if self._logger is not None:
self._logger.flush()
else:
braintrust.flush()
def force_flush(self) -> None:
if self._logger is not None:
self._logger.flush()
else:
braintrust.flush()
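The token-accounting branch in `_generation_log_data` is the densest part of the deleted processor: it accepts either `prompt_tokens`/`completion_tokens` or `input_tokens`/`output_tokens` style usage dicts and folds the cache counters in as well. Isolated below as a standalone helper (the name `usage_to_metrics` is hypothetical) so the mapping can be sanity-checked on its own:

```python
from typing import Any, Dict


def usage_to_metrics(usage: Dict[str, Any]) -> Dict[str, Any]:
    """Replicates the token-count normalization used in _generation_log_data()."""
    metrics: Dict[str, Any] = {}
    if "prompt_tokens" in usage:
        metrics["prompt_tokens"] = usage["prompt_tokens"]
    elif "input_tokens" in usage:
        metrics["prompt_tokens"] = usage["input_tokens"]
    if "completion_tokens" in usage:
        metrics["completion_tokens"] = usage["completion_tokens"]
    elif "output_tokens" in usage:
        metrics["completion_tokens"] = usage["output_tokens"]
    if "total_tokens" in usage:
        metrics["tokens"] = usage["total_tokens"]
    elif "input_tokens" in usage and "output_tokens" in usage:
        metrics["tokens"] = usage["input_tokens"] + usage["output_tokens"]
    if "cache_read_input_tokens" in usage:
        metrics["prompt_cached_tokens"] = usage["cache_read_input_tokens"]
    if "cache_creation_input_tokens" in usage:
        metrics["prompt_cache_creation_tokens"] = usage["cache_creation_input_tokens"]
    return metrics


if __name__ == "__main__":
    print(usage_to_metrics({"input_tokens": 120, "output_tokens": 30}))
    # {'prompt_tokens': 120, 'completion_tokens': 30, 'tokens': 150}
```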

View File

@@ -1,21 +0,0 @@
from .processor_interface import TracingProcessor
from .provider import DefaultTraceProvider
from .setup import get_trace_provider
from .setup import set_trace_provider
def add_trace_processor(span_processor: TracingProcessor) -> None:
"""
Adds a new trace processor. This processor will receive all traces/spans.
"""
get_trace_provider().register_processor(span_processor)
def set_trace_processors(processors: list[TracingProcessor]) -> None:
"""
Set the list of trace processors. This will replace the current list of processors.
"""
get_trace_provider().set_processors(processors)
set_trace_provider(DefaultTraceProvider())
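The two helpers above differ in one important way: `add_trace_processor` appends to the registered set, while `set_trace_processors` replaces it. A small sketch, assuming the pre-change `onyx.tracing.framework` layout and that these helpers are exported from the package `__init__` (the file path is not shown in this view):

```python
from typing import Any

from onyx.tracing.framework import add_trace_processor, set_trace_processors
from onyx.tracing.framework.processor_interface import TracingProcessor
from onyx.tracing.framework.spans import Span
from onyx.tracing.framework.traces import Trace


class NoopProcessor(TracingProcessor):
    def on_trace_start(self, trace: Trace) -> None: ...
    def on_trace_end(self, trace: Trace) -> None: ...
    def on_span_start(self, span: Span[Any]) -> None: ...
    def on_span_end(self, span: Span[Any]) -> None: ...
    def shutdown(self) -> None: ...
    def force_flush(self) -> None: ...


add_trace_processor(NoopProcessor())     # appends to whatever is registered
add_trace_processor(NoopProcessor())     # now two processors receive events
set_trace_processors([NoopProcessor()])  # replaces the whole list with one
```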

View File

@@ -1,21 +0,0 @@
from typing import Any
from .create import get_current_span
from .spans import Span
from .spans import SpanError
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")

View File

@@ -1,192 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING
from .setup import get_trace_provider
from .span_data import AgentSpanData
from .span_data import FunctionSpanData
from .span_data import GenerationSpanData
from .spans import Span
from .traces import Trace
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
pass
logger = setup_logger(__name__)
def trace(
workflow_name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""
Create a new trace. The trace will not be started automatically; you should either use
it as a context manager (`with trace(...):`) or call `trace.start()` + `trace.finish()`
manually.
In addition to the workflow name and optional grouping identifier, you can provide
an arbitrary metadata dictionary to attach additional user-defined information to
the trace.
Args:
workflow_name: The name of the logical app or workflow. For example, you might provide
"code_bot" for a coding agent, or "customer_support_agent" for a customer support agent.
trace_id: The ID of the trace. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_trace_id()` to generate a trace ID, to guarantee that IDs are
correctly formatted.
group_id: Optional grouping identifier to link multiple traces from the same conversation
or process. For instance, you might use a chat thread ID.
metadata: Optional dictionary of additional metadata to attach to the trace.
disabled: If True, we will return a Trace but the Trace will not be recorded.
Returns:
The newly created trace object.
"""
current_trace = get_trace_provider().get_current_trace()
if current_trace:
logger.warning(
"Trace already exists. Creating a new trace, but this is probably a mistake."
)
return get_trace_provider().create_trace(
name=workflow_name,
trace_id=trace_id,
group_id=group_id,
metadata=metadata,
disabled=disabled,
)
def get_current_trace() -> Trace | None:
"""Returns the currently active trace, if present."""
return get_trace_provider().get_current_trace()
def get_current_span() -> Span[Any] | None:
"""Returns the currently active span, if present."""
return get_trace_provider().get_current_span()
def agent_span(
name: str,
handoffs: list[str] | None = None,
tools: list[str] | None = None,
output_type: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[AgentSpanData]:
"""Create a new agent span. The span will not be started automatically, you should either do
`with agent_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
name: The name of the agent.
handoffs: Optional list of agent names to which this agent could hand off control.
tools: Optional list of tool names available to this agent.
output_type: Optional name of the output type produced by the agent.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created agent span.
"""
return get_trace_provider().create_span(
span_data=AgentSpanData(
name=name, handoffs=handoffs, tools=tools, output_type=output_type
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def function_span(
name: str,
input: str | None = None,
output: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[FunctionSpanData]:
"""Create a new function span. The span will not be started automatically, you should either do
`with function_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
name: The name of the function.
input: The input to the function.
output: The output of the function.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created function span.
"""
return get_trace_provider().create_span(
span_data=FunctionSpanData(name=name, input=input, output=output),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def generation_span(
input: Sequence[Mapping[str, Any]] | None = None,
output: Sequence[Mapping[str, Any]] | None = None,
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
usage: dict[str, Any] | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[GenerationSpanData]:
"""Create a new generation span. The span will not be started automatically, you should either
do `with generation_span() ...` or call `span.start()` + `span.finish()` manually.
This span captures the details of a model generation, including the
input message sequence, any generated outputs, the model name and
configuration, and usage data. If you only need to capture a model
response identifier, use `response_span()` instead.
Args:
input: The sequence of input messages sent to the model.
output: The sequence of output messages received from the model.
model: The model identifier used for the generation.
model_config: The model configuration (hyperparameters) used.
usage: A dictionary of usage information (input tokens, output tokens, etc.).
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created generation span.
"""
return get_trace_provider().create_span(
span_data=GenerationSpanData(
input=input,
output=output,
model=model,
model_config=model_config,
usage=usage,
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
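Putting the factories above together, a trace is opened as a context manager and agent/function/generation spans nest under it automatically. A minimal sketch; import paths assume the pre-change `onyx.tracing.framework` layout, the model name and payloads are illustrative, and nothing is exported unless a processor has been registered:

```python
from onyx.tracing.framework.create import (
    agent_span,
    function_span,
    generation_span,
    trace,
)

with trace("demo_workflow", metadata={"tenant": "example"}):
    with agent_span(name="researcher", tools=["web_search"]):
        with function_span(name="web_search", input='{"q": "onyx"}') as fn:
            fn.span_data.output = '["result-1", "result-2"]'
        with generation_span(
            input=[{"role": "user", "content": "summarize the results"}],
            model="example-model",
            usage={"input_tokens": 42, "output_tokens": 7},
        ) as gen:
            gen.span_data.output = [{"role": "assistant", "content": "done"}]
```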

View File

@@ -1,136 +0,0 @@
import abc
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .spans import Span
from .traces import Trace
class TracingProcessor(abc.ABC):
"""Interface for processing and monitoring traces and spans in the OpenAI Agents system.
This abstract class defines the interface that all tracing processors must implement.
Processors receive notifications when traces and spans start and end, allowing them
to collect, process, and export tracing data.
Example:
```python
class CustomProcessor(TracingProcessor):
def __init__(self):
self.active_traces = {}
self.active_spans = {}
def on_trace_start(self, trace):
self.active_traces[trace.trace_id] = trace
def on_trace_end(self, trace):
# Process completed trace
del self.active_traces[trace.trace_id]
def on_span_start(self, span):
self.active_spans[span.span_id] = span
def on_span_end(self, span):
# Process completed span
del self.active_spans[span.span_id]
def shutdown(self):
# Clean up resources
self.active_traces.clear()
self.active_spans.clear()
def force_flush(self):
# Force processing of any queued items
pass
```
Notes:
- All methods should be thread-safe
- Methods should not block for long periods
- Handle errors gracefully to prevent disrupting agent execution
"""
@abc.abstractmethod
def on_trace_start(self, trace: "Trace") -> None:
"""Called when a new trace begins execution.
Args:
trace: The trace that started. Contains workflow name and metadata.
Notes:
- Called synchronously on trace start
- Should return quickly to avoid blocking execution
- Any errors should be caught and handled internally
"""
@abc.abstractmethod
def on_trace_end(self, trace: "Trace") -> None:
"""Called when a trace completes execution.
Args:
trace: The completed trace containing all spans and results.
Notes:
- Called synchronously when trace finishes
- Good time to export/process the complete trace
- Should handle cleanup of any trace-specific resources
"""
@abc.abstractmethod
def on_span_start(self, span: "Span[Any]") -> None:
"""Called when a new span begins execution.
Args:
span: The span that started. Contains operation details and context.
Notes:
- Called synchronously on span start
- Should return quickly to avoid blocking execution
- Spans are automatically nested under current trace/span
"""
@abc.abstractmethod
def on_span_end(self, span: "Span[Any]") -> None:
"""Called when a span completes execution.
Args:
span: The completed span containing execution results.
Notes:
- Called synchronously when span finishes
- Should not block or raise exceptions
- Good time to export/process the individual span
"""
@abc.abstractmethod
def shutdown(self) -> None:
"""Called when the application stops to clean up resources.
Should perform any necessary cleanup like:
- Flushing queued traces/spans
- Closing connections
- Releasing resources
"""
@abc.abstractmethod
def force_flush(self) -> None:
"""Forces immediate processing of any queued traces/spans.
Notes:
- Should process all queued items before returning
- Useful before shutdown or when immediate processing is needed
- May block while processing completes
"""
class TracingExporter(abc.ABC):
"""Exports traces and spans. For example, could log them or send them to a backend."""
@abc.abstractmethod
def export(self, items: list["Trace | Span[Any]"]) -> None:
"""Exports a list of traces and spans.
Args:
items: The items to export.
"""

View File

@@ -1,319 +0,0 @@
from __future__ import annotations
import threading
import uuid
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timezone
from typing import Any
from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan
from .spans import Span
from .spans import SpanImpl
from .spans import TSpanData
from .traces import NoOpTrace
from .traces import Trace
from .traces import TraceImpl
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
class SynchronousMultiTracingProcessor(TracingProcessor):
"""
Forwards all calls to a list of TracingProcessors, in order of registration.
"""
def __init__(self) -> None:
# Using a tuple to avoid race conditions when iterating over processors
self._processors: tuple[TracingProcessor, ...] = ()
self._lock = threading.Lock()
def add_tracing_processor(self, tracing_processor: TracingProcessor) -> None:
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
with self._lock:
self._processors += (tracing_processor,)
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""
Set the list of processors. This will replace the current list of processors.
"""
with self._lock:
self._processors = tuple(processors)
def on_trace_start(self, trace: Trace) -> None:
"""
Called when a trace is started.
"""
for processor in self._processors:
try:
processor.on_trace_start(trace)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_trace_start: {e}"
)
def on_trace_end(self, trace: Trace) -> None:
"""
Called when a trace is finished.
"""
for processor in self._processors:
try:
processor.on_trace_end(trace)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_trace_end: {e}"
)
def on_span_start(self, span: Span[Any]) -> None:
"""
Called when a span is started.
"""
for processor in self._processors:
try:
processor.on_span_start(span)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_span_start: {e}"
)
def on_span_end(self, span: Span[Any]) -> None:
"""
Called when a span is finished.
"""
for processor in self._processors:
try:
processor.on_span_end(span)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_span_end: {e}"
)
def shutdown(self) -> None:
"""
Called when the application stops.
"""
for processor in self._processors:
logger.debug(f"Shutting down trace processor {processor}")
try:
processor.shutdown()
except Exception as e:
logger.error(f"Error shutting down trace processor {processor}: {e}")
def force_flush(self) -> None:
"""
Force the processors to flush their buffers.
"""
for processor in self._processors:
try:
processor.force_flush()
except Exception as e:
logger.error(f"Error flushing trace processor {processor}: {e}")
class TraceProvider(ABC):
"""Interface for creating traces and spans."""
@abstractmethod
def register_processor(self, processor: TracingProcessor) -> None:
"""Add a processor that will receive all traces and spans."""
@abstractmethod
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""Replace the list of processors with ``processors``."""
@abstractmethod
def get_current_trace(self) -> Trace | None:
"""Return the currently active trace, if any."""
@abstractmethod
def get_current_span(self) -> Span[Any] | None:
"""Return the currently active span, if any."""
@abstractmethod
def time_iso(self) -> str:
"""Return the current time in ISO 8601 format."""
@abstractmethod
def gen_trace_id(self) -> str:
"""Generate a new trace identifier."""
@abstractmethod
def gen_span_id(self) -> str:
"""Generate a new span identifier."""
@abstractmethod
def gen_group_id(self) -> str:
"""Generate a new group identifier."""
@abstractmethod
def create_trace(
self,
name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""Create a new trace."""
@abstractmethod
def create_span(
self,
span_data: TSpanData,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TSpanData]:
"""Create a new span."""
@abstractmethod
def shutdown(self) -> None:
"""Clean up any resources used by the provider."""
class DefaultTraceProvider(TraceProvider):
def __init__(self) -> None:
self._multi_processor = SynchronousMultiTracingProcessor()
def register_processor(self, processor: TracingProcessor) -> None:
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
self._multi_processor.add_tracing_processor(processor)
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""
Set the list of processors. This will replace the current list of processors.
"""
self._multi_processor.set_processors(processors)
def get_current_trace(self) -> Trace | None:
"""
Returns the currently active trace, if any.
"""
return Scope.get_current_trace()
def get_current_span(self) -> Span[Any] | None:
"""
Returns the currently active span, if any.
"""
return Scope.get_current_span()
def time_iso(self) -> str:
"""Return the current time in ISO 8601 format."""
return datetime.now(timezone.utc).isoformat()
def gen_trace_id(self) -> str:
"""Generate a new trace ID."""
return f"trace_{uuid.uuid4().hex}"
def gen_span_id(self) -> str:
"""Generate a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}"
def gen_group_id(self) -> str:
"""Generate a new group ID."""
return f"group_{uuid.uuid4().hex[:24]}"
def create_trace(
self,
name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""
Create a new trace.
"""
if disabled:
logger.debug(f"Tracing is disabled. Not creating trace {name}")
return NoOpTrace()
trace_id = trace_id or self.gen_trace_id()
logger.debug(f"Creating trace {name} with id {trace_id}")
return TraceImpl(
name=name,
trace_id=trace_id,
group_id=group_id,
metadata=metadata,
processor=self._multi_processor,
)
def create_span(
self,
span_data: TSpanData,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TSpanData]:
"""
Create a new span.
"""
if disabled:
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
return NoOpSpan(span_data)
trace_id: str
parent_id: str | None
if not parent:
current_span = Scope.get_current_span()
current_trace = Scope.get_current_trace()
if current_trace is None:
logger.error(
"No active trace. Make sure to start a trace with `trace()` first "
"Returning NoOpSpan."
)
return NoOpSpan(span_data)
elif isinstance(current_trace, NoOpTrace) or isinstance(
current_span, NoOpSpan
):
logger.debug(
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
)
return NoOpSpan(span_data)
parent_id = current_span.span_id if current_span else None
trace_id = current_trace.trace_id
elif isinstance(parent, Trace):
if isinstance(parent, NoOpTrace):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
trace_id = parent.trace_id
parent_id = None
elif isinstance(parent, Span):
if isinstance(parent, NoOpSpan):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
parent_id = parent.span_id
trace_id = parent.trace_id
else:
# This should never happen, but mypy needs it
raise ValueError(f"Invalid parent type: {type(parent)}")
logger.debug(f"Creating span {span_data} with id {span_id}")
return SpanImpl(
trace_id=trace_id,
span_id=span_id or self.gen_span_id(),
parent_id=parent_id,
processor=self._multi_processor,
span_data=span_data,
)
def shutdown(self) -> None:
try:
logger.debug("Shutting down trace provider")
self._multi_processor.shutdown()
except Exception as e:
logger.error(f"Error shutting down trace provider: {e}")

View File

@@ -1,49 +0,0 @@
import contextvars
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .spans import Span
from .traces import Trace
_current_span: contextvars.ContextVar["Span[Any] | None"] = contextvars.ContextVar(
"current_span", default=None
)
_current_trace: contextvars.ContextVar["Trace | None"] = contextvars.ContextVar(
"current_trace", default=None
)
class Scope:
"""
Manages the current span and trace in the context.
"""
@classmethod
def get_current_span(cls) -> "Span[Any] | None":
return _current_span.get()
@classmethod
def set_current_span(
cls, span: "Span[Any] | None"
) -> "contextvars.Token[Span[Any] | None]":
return _current_span.set(span)
@classmethod
def reset_current_span(cls, token: "contextvars.Token[Span[Any] | None]") -> None:
_current_span.reset(token)
@classmethod
def get_current_trace(cls) -> "Trace | None":
return _current_trace.get()
@classmethod
def set_current_trace(
cls, trace: "Trace | None"
) -> "contextvars.Token[Trace | None]":
return _current_trace.set(trace)
@classmethod
def reset_current_trace(cls, token: "contextvars.Token[Trace | None]") -> None:
_current_trace.reset(token)
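`Scope` is a thin wrapper over two `ContextVar`s, and the token discipline is what makes nested set/reset safe. The same pattern in isolation, with a plain string standing in for a span object:

```python
import contextvars

_current_span: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_span", default=None
)

token = _current_span.set("span_abc")   # like Scope.set_current_span(span)
assert _current_span.get() == "span_abc"

_current_span.reset(token)              # like Scope.reset_current_span(token)
assert _current_span.get() is None
```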

View File

@@ -1,21 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .provider import TraceProvider
GLOBAL_TRACE_PROVIDER: TraceProvider | None = None
def set_trace_provider(provider: TraceProvider) -> None:
"""Set the global trace provider used by tracing utilities."""
global GLOBAL_TRACE_PROVIDER
GLOBAL_TRACE_PROVIDER = provider
def get_trace_provider() -> TraceProvider:
"""Get the global trace provider used by tracing utilities."""
if GLOBAL_TRACE_PROVIDER is None:
raise RuntimeError("Trace provider not set")
return GLOBAL_TRACE_PROVIDER
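The contract here is small: something must call `set_trace_provider` during startup (the package `__init__` above already does so at import time) and everything else reaches the provider through the getter, which raises if startup was skipped. A brief sketch under the pre-change paths:

```python
from onyx.tracing.framework.provider import DefaultTraceProvider
from onyx.tracing.framework.setup import get_trace_provider, set_trace_provider

set_trace_provider(DefaultTraceProvider())  # normally done once at startup

# Downstream code never touches the global directly; it goes through the
# getter, which raises RuntimeError if no provider was installed.
provider = get_trace_provider()
print(provider.gen_trace_id())
```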

View File

@@ -1,130 +0,0 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
class SpanData(abc.ABC):
"""
Represents span data in the trace.
"""
@abc.abstractmethod
def export(self) -> dict[str, Any]:
"""Export the span data as a dictionary."""
@property
@abc.abstractmethod
def type(self) -> str:
"""Return the type of the span."""
class AgentSpanData(SpanData):
"""
Represents an Agent Span in the trace.
Includes name, handoffs, tools, and output type.
"""
__slots__ = ("name", "handoffs", "tools", "output_type")
def __init__(
self,
name: str,
handoffs: list[str] | None = None,
tools: list[str] | None = None,
output_type: str | None = None,
):
self.name = name
self.handoffs: list[str] | None = handoffs
self.tools: list[str] | None = tools
self.output_type: str | None = output_type
@property
def type(self) -> str:
return "agent"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"name": self.name,
"handoffs": self.handoffs,
"tools": self.tools,
"output_type": self.output_type,
}
class FunctionSpanData(SpanData):
"""
Represents a Function Span in the trace.
Includes input, output and MCP data (if applicable).
"""
__slots__ = ("name", "input", "output", "mcp_data")
def __init__(
self,
name: str,
input: str | None,
output: Any | None,
mcp_data: dict[str, Any] | None = None,
):
self.name = name
self.input = input
self.output = output
self.mcp_data = mcp_data
@property
def type(self) -> str:
return "function"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"name": self.name,
"input": self.input,
"output": str(self.output) if self.output else None,
"mcp_data": self.mcp_data,
}
class GenerationSpanData(SpanData):
"""
Represents a Generation Span in the trace.
Includes input, output, model, model configuration, and usage.
"""
__slots__ = (
"input",
"output",
"model",
"model_config",
"usage",
)
def __init__(
self,
input: Sequence[Mapping[str, Any]] | None = None,
output: Sequence[Mapping[str, Any]] | None = None,
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
usage: dict[str, Any] | None = None,
):
self.input = input
self.output = output
self.model = model
self.model_config = model_config
self.usage = usage
@property
def type(self) -> str:
return "generation"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"input": self.input,
"output": self.output,
"model": self.model,
"model_config": self.model_config,
"usage": self.usage,
}
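Each `SpanData` subclass is a typed bag plus an `export()` dict. A quick check of the generation variant's export shape (the model name and payloads are illustrative; the path assumes the pre-change layout):

```python
from onyx.tracing.framework.span_data import GenerationSpanData

data = GenerationSpanData(
    input=[{"role": "user", "content": "hi"}],
    output=[{"role": "assistant", "content": "hello"}],
    model="example-model",
    usage={"input_tokens": 3, "output_tokens": 2},
)
assert data.type == "generation"
print(data.export())
# -> {'type': 'generation', 'input': [...], 'output': [...],
#     'model': 'example-model', 'model_config': None, 'usage': {...}}
```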

View File

@@ -1,356 +0,0 @@
from __future__ import annotations
import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from typing_extensions import TypedDict
from . import util
from .processor_interface import TracingProcessor
from .scope import Scope
from .span_data import SpanData
TSpanData = TypeVar("TSpanData", bound=SpanData)
class SpanError(TypedDict):
"""Represents an error that occurred during span execution.
Attributes:
message: A human-readable error description
data: Optional dictionary containing additional error context
"""
message: str
data: dict[str, Any] | None
class Span(abc.ABC, Generic[TSpanData]):
"""Base class for representing traceable operations with timing and context.
A span represents a single operation within a trace (e.g., an LLM call, tool execution,
or agent run). Spans track timing, relationships between operations, and operation-specific
data.
Type Args:
TSpanData: The type of span-specific data this span contains.
Example:
```python
# Creating a custom span
with custom_span("database_query", {
"operation": "SELECT",
"table": "users"
}) as span:
results = await db.query("SELECT * FROM users")
span.set_output({"count": len(results)})
# Handling errors in spans
with custom_span("risky_operation") as span:
try:
result = perform_risky_operation()
except Exception as e:
span.set_error({
"message": str(e),
"data": {"operation": "risky_operation"}
})
raise
```
Notes:
- Spans automatically nest under the current trace
- Use context managers for reliable start/finish
- Include relevant data but avoid sensitive information
- Handle errors properly using set_error()
"""
@property
@abc.abstractmethod
def trace_id(self) -> str:
"""The ID of the trace this span belongs to.
Returns:
str: Unique identifier of the parent trace.
"""
@property
@abc.abstractmethod
def span_id(self) -> str:
"""Unique identifier for this span.
Returns:
str: The span's unique ID within its trace.
"""
@property
@abc.abstractmethod
def span_data(self) -> TSpanData:
"""Operation-specific data for this span.
Returns:
TSpanData: Data specific to this type of span (e.g., LLM generation data).
"""
@abc.abstractmethod
def start(self, mark_as_current: bool = False) -> None:
"""
Start the span.
Args:
mark_as_current: If true, the span will be marked as the current span.
"""
@abc.abstractmethod
def finish(self, reset_current: bool = False) -> None:
"""
Finish the span.
Args:
reset_current: If true, the current span will be reset to the previous span.
"""
@abc.abstractmethod
def __enter__(self) -> Span[TSpanData]:
pass
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
@property
@abc.abstractmethod
def parent_id(self) -> str | None:
"""ID of the parent span, if any.
Returns:
str | None: The parent span's ID, or None if this is a root span.
"""
@abc.abstractmethod
def set_error(self, error: SpanError) -> None:
pass
@property
@abc.abstractmethod
def error(self) -> SpanError | None:
"""Any error that occurred during span execution.
Returns:
SpanError | None: Error details if an error occurred, None otherwise.
"""
@abc.abstractmethod
def export(self) -> dict[str, Any] | None:
pass
@property
@abc.abstractmethod
def started_at(self) -> str | None:
"""When the span started execution.
Returns:
str | None: ISO format timestamp of span start, None if not started.
"""
@property
@abc.abstractmethod
def ended_at(self) -> str | None:
"""When the span finished execution.
Returns:
str | None: ISO format timestamp of span end, None if not finished.
"""
class NoOpSpan(Span[TSpanData]):
"""A no-op implementation of Span that doesn't record any data.
Used when tracing is disabled but span operations still need to work.
Args:
span_data: The operation-specific data for this span.
"""
__slots__ = ("_span_data", "_prev_span_token")
def __init__(self, span_data: TSpanData):
self._span_data = span_data
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
@property
def trace_id(self) -> str:
return "no-op"
@property
def span_id(self) -> str:
return "no-op"
@property
def span_data(self) -> TSpanData:
return self._span_data
@property
def parent_id(self) -> str | None:
return None
def start(self, mark_as_current: bool = False) -> None:
if mark_as_current:
self._prev_span_token = Scope.set_current_span(self)
def finish(self, reset_current: bool = False) -> None:
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
self._prev_span_token = None
def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
reset_current = True
if exc_type is GeneratorExit:
reset_current = False
self.finish(reset_current=reset_current)
def set_error(self, error: SpanError) -> None:
pass
@property
def error(self) -> SpanError | None:
return None
def export(self) -> dict[str, Any] | None:
return None
@property
def started_at(self) -> str | None:
return None
@property
def ended_at(self) -> str | None:
return None
class SpanImpl(Span[TSpanData]):
__slots__ = (
"_trace_id",
"_span_id",
"_parent_id",
"_started_at",
"_ended_at",
"_error",
"_prev_span_token",
"_processor",
"_span_data",
)
def __init__(
self,
trace_id: str,
span_id: str | None,
parent_id: str | None,
processor: TracingProcessor,
span_data: TSpanData,
):
self._trace_id = trace_id
self._span_id = span_id or util.gen_span_id()
self._parent_id = parent_id
self._started_at: str | None = None
self._ended_at: str | None = None
self._processor = processor
self._error: SpanError | None = None
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
self._span_data = span_data
@property
def trace_id(self) -> str:
return self._trace_id
@property
def span_id(self) -> str:
return self._span_id
@property
def span_data(self) -> TSpanData:
return self._span_data
@property
def parent_id(self) -> str | None:
return self._parent_id
def start(self, mark_as_current: bool = False) -> None:
if self.started_at is not None:
return
self._started_at = util.time_iso()
self._processor.on_span_start(self)
if mark_as_current:
self._prev_span_token = Scope.set_current_span(self)
def finish(self, reset_current: bool = False) -> None:
if self.ended_at is not None:
return
self._ended_at = util.time_iso()
self._processor.on_span_end(self)
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
self._prev_span_token = None
def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
reset_current = True
if exc_type is GeneratorExit:
reset_current = False
self.finish(reset_current=reset_current)
def set_error(self, error: SpanError) -> None:
self._error = error
@property
def error(self) -> SpanError | None:
return self._error
@property
def started_at(self) -> str | None:
return self._started_at
@property
def ended_at(self) -> str | None:
return self._ended_at
def export(self) -> dict[str, Any] | None:
return {
"object": "trace.span",
"id": self.span_id,
"trace_id": self.trace_id,
"parent_id": self._parent_id,
"started_at": self._started_at,
"ended_at": self._ended_at,
"span_data": self.span_data.export(),
"error": self._error,
}
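Errors set via `set_error` travel with the span and show up in `export()`, which is what the Braintrust and OpenInference processors read back. A sketch using `SpanImpl` with a do-nothing processor, assuming the pre-change import paths:

```python
from typing import Any

from onyx.tracing.framework.processor_interface import TracingProcessor
from onyx.tracing.framework.span_data import FunctionSpanData
from onyx.tracing.framework.spans import Span, SpanError, SpanImpl
from onyx.tracing.framework.traces import Trace


class NullProcessor(TracingProcessor):
    def on_trace_start(self, trace: Trace) -> None: ...
    def on_trace_end(self, trace: Trace) -> None: ...
    def on_span_start(self, span: Span[Any]) -> None: ...
    def on_span_end(self, span: Span[Any]) -> None: ...
    def shutdown(self) -> None: ...
    def force_flush(self) -> None: ...


span = SpanImpl(
    trace_id="trace_demo",
    span_id=None,  # a span_... id is generated
    parent_id=None,
    processor=NullProcessor(),
    span_data=FunctionSpanData(name="lookup", input="{}", output=None),
)

with span:
    try:
        raise TimeoutError("backend unavailable")
    except TimeoutError as e:
        span.set_error(SpanError(message=str(e), data=None))

exported = span.export()
assert exported is not None and exported["error"]["message"] == "backend unavailable"
```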

View File

@@ -1,287 +0,0 @@
from __future__ import annotations
import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING
from . import util
from .scope import Scope
if TYPE_CHECKING:
from .processor_interface import TracingProcessor
class Trace(abc.ABC):
"""A complete end-to-end workflow containing related spans and metadata.
A trace represents a logical workflow or operation (e.g., "Customer Service Query"
or "Code Generation") and contains all the spans (individual operations) that occur
during that workflow.
Example:
```python
# Basic trace usage
with trace("Order Processing") as t:
validation_result = await Runner.run(validator, order_data)
if validation_result.approved:
await Runner.run(processor, order_data)
# Trace with metadata and grouping
with trace(
"Customer Service",
group_id="chat_123",
metadata={"customer": "user_456"}
) as t:
result = await Runner.run(support_agent, query)
```
Notes:
- Use descriptive workflow names
- Group related traces with consistent group_ids
- Add relevant metadata for filtering/analysis
- Use context managers for reliable cleanup
- Consider privacy when adding trace data
"""
@abc.abstractmethod
def __enter__(self) -> Trace:
pass
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
@abc.abstractmethod
def start(self, mark_as_current: bool = False) -> None:
"""Start the trace and optionally mark it as the current trace.
Args:
mark_as_current: If true, marks this trace as the current trace
in the execution context.
Notes:
- Must be called before any spans can be added
- Only one trace can be current at a time
- Thread-safe when using mark_as_current
"""
@abc.abstractmethod
def finish(self, reset_current: bool = False) -> None:
"""Finish the trace and optionally reset the current trace.
Args:
reset_current: If true, resets the current trace to the previous
trace in the execution context.
Notes:
- Must be called to complete the trace
- Finalizes all open spans
- Thread-safe when using reset_current
"""
@property
@abc.abstractmethod
def trace_id(self) -> str:
"""Get the unique identifier for this trace.
Returns:
str: The trace's unique ID in the format 'trace_<32_alphanumeric>'
Notes:
- IDs are globally unique
- Used to link spans to their parent trace
- Can be used to look up traces in the dashboard
"""
@property
@abc.abstractmethod
def name(self) -> str:
"""Get the human-readable name of this workflow trace.
Returns:
str: The workflow name (e.g., "Customer Service", "Data Processing")
Notes:
- Should be descriptive and meaningful
- Used for grouping and filtering in the dashboard
- Helps identify the purpose of the trace
"""
@abc.abstractmethod
def export(self) -> dict[str, Any] | None:
"""Export the trace data as a serializable dictionary.
Returns:
dict | None: Dictionary containing trace data, or None if tracing is disabled.
Notes:
- Includes all spans and their data
- Used for sending traces to backends
- May include metadata and group ID
"""
class NoOpTrace(Trace):
"""A no-op implementation of Trace that doesn't record any data.
Used when tracing is disabled but trace operations still need to work.
Maintains proper context management but doesn't store or export any data.
Example:
```python
# When tracing is disabled, traces become NoOpTrace
with trace("Disabled Workflow") as t:
# Operations still work but nothing is recorded
await Runner.run(agent, "query")
```
"""
def __init__(self) -> None:
self._started = False
self._prev_context_token: contextvars.Token[Trace | None] | None = None
def __enter__(self) -> Trace:
if self._started:
return self
self._started = True
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.finish(reset_current=True)
def start(self, mark_as_current: bool = False) -> None:
if mark_as_current:
self._prev_context_token = Scope.set_current_trace(self)
def finish(self, reset_current: bool = False) -> None:
if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
self._prev_context_token = None
@property
def trace_id(self) -> str:
"""The trace's unique identifier.
Returns:
str: A unique ID for this trace.
"""
return "no-op"
@property
def name(self) -> str:
"""The workflow name for this trace.
Returns:
str: Human-readable name describing this workflow.
"""
return "no-op"
def export(self) -> dict[str, Any] | None:
"""Export the trace data as a dictionary.
Returns:
dict | None: Trace data in exportable format, or None if no data.
"""
return None
NO_OP_TRACE = NoOpTrace()
class TraceImpl(Trace):
"""
A trace that will be recorded by the tracing library.
"""
__slots__ = (
"_name",
"_trace_id",
"group_id",
"metadata",
"_prev_context_token",
"_processor",
"_started",
)
def __init__(
self,
name: str,
trace_id: str | None,
group_id: str | None,
metadata: dict[str, Any] | None,
processor: TracingProcessor,
):
self._name = name
self._trace_id = trace_id or util.gen_trace_id()
self.group_id = group_id
self.metadata = metadata
self._prev_context_token: contextvars.Token[Trace | None] | None = None
self._processor = processor
self._started = False
@property
def trace_id(self) -> str:
return self._trace_id
@property
def name(self) -> str:
return self._name
def start(self, mark_as_current: bool = False) -> None:
if self._started:
return
self._started = True
self._processor.on_trace_start(self)
if mark_as_current:
self._prev_context_token = Scope.set_current_trace(self)
def finish(self, reset_current: bool = False) -> None:
if not self._started:
return
self._processor.on_trace_end(self)
if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
self._prev_context_token = None
def __enter__(self) -> Trace:
if self._started:
return self
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.finish(reset_current=exc_type is not GeneratorExit)
def export(self) -> dict[str, Any] | None:
return {
"object": "trace",
"id": self.trace_id,
"workflow_name": self.name,
"group_id": self.group_id,
"metadata": self.metadata,
}
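Besides the context-manager form, the docstrings above call out a manual `start()`/`finish()` path, which is useful when a workflow does not fit inside a single `with` block. A sketch assuming the pre-change paths and that tracing processors are already configured:

```python
from onyx.tracing.framework.create import trace

t = trace(
    "chat_turn",
    group_id="chat_session_123",   # ties related turns together
    metadata={"user": "user_456"},
)
t.start(mark_as_current=True)
try:
    pass  # run the workflow here; new spans nest under this trace
finally:
    # reset_current=True restores whichever trace was current before start()
    t.finish(reset_current=True)

print(t.export())
```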

View File

@@ -1,18 +0,0 @@
import uuid
from datetime import datetime
from datetime import timezone
def time_iso() -> str:
"""Return the current time in ISO 8601 format."""
return datetime.now(timezone.utc).isoformat()
def gen_trace_id() -> str:
"""Generate a new trace ID."""
return f"trace_{uuid.uuid4().hex}"
def gen_span_id() -> str:
"""Generate a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}"

View File

@@ -1,3 +1,4 @@
from onyx.configs.app_configs import LANGFUSE_HOST
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.utils.logger import setup_logger
@@ -12,20 +13,16 @@ def setup_langfuse_if_creds_available() -> None:
return
import nest_asyncio # type: ignore
nest_asyncio.apply()
from langfuse import get_client
from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor
nest_asyncio.apply()
OpenAIAgentsInstrumentor().instrument()
# TODO: this is how the tracing processor will look once we migrate over to new framework
# config = TraceConfig()
# tracer_provider = trace_api.get_tracer_provider()
# tracer = OITracer(
# trace_api.get_tracer(__name__, __version__, tracer_provider),
# config=config,
# )
# set_trace_processors(
# [OpenInferenceTracingProcessor(cast(trace_api.Tracer, tracer))]
# )
langfuse = get_client()
try:
if langfuse.auth_check():
logger.notice(f"Langfuse authentication successful (host: {LANGFUSE_HOST})")
else:
logger.warning("Langfuse authentication failed")
except Exception as e:
logger.error(f"Error setting up Langfuse: {e}")

View File

@@ -1,450 +0,0 @@
from __future__ import annotations
import logging
from collections import OrderedDict
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from typing import Union
from openinference.instrumentation import safe_json_dumps
from openinference.semconv.trace import MessageAttributes
from openinference.semconv.trace import MessageContentAttributes
from openinference.semconv.trace import OpenInferenceLLMProviderValues
from openinference.semconv.trace import OpenInferenceLLMSystemValues
from openinference.semconv.trace import OpenInferenceMimeTypeValues
from openinference.semconv.trace import OpenInferenceSpanKindValues
from openinference.semconv.trace import SpanAttributes
from openinference.semconv.trace import ToolAttributes
from openinference.semconv.trace import ToolCallAttributes
from opentelemetry.context import attach
from opentelemetry.context import detach
from opentelemetry.trace import set_span_in_context
from opentelemetry.trace import Span as OtelSpan
from opentelemetry.trace import Status
from opentelemetry.trace import StatusCode
from opentelemetry.trace import Tracer
from opentelemetry.util.types import AttributeValue
from onyx.tracing.framework.processor_interface import TracingProcessor
from onyx.tracing.framework.span_data import AgentSpanData
from onyx.tracing.framework.span_data import FunctionSpanData
from onyx.tracing.framework.span_data import GenerationSpanData
from onyx.tracing.framework.span_data import SpanData
from onyx.tracing.framework.spans import Span
from onyx.tracing.framework.traces import Trace
logger = logging.getLogger(__name__)
class OpenInferenceTracingProcessor(TracingProcessor):
_MAX_HANDOFFS_IN_FLIGHT = 1000
def __init__(self, tracer: Tracer) -> None:
self._tracer = tracer
self._root_spans: dict[str, OtelSpan] = {}
self._otel_spans: dict[str, OtelSpan] = {}
self._tokens: dict[str, object] = {}
# This captures in flight handoff. Once the handoff is complete, the entry is deleted
# If the handoff does not complete, the entry stays in the dict.
# Use an OrderedDict and _MAX_HANDOFFS_IN_FLIGHT to cap the size of the dict
# in case there are large numbers of orphaned handoffs
self._reverse_handoffs_dict: OrderedDict[str, str] = OrderedDict()
self._first_input: dict[str, Any] = {}
self._last_output: dict[str, Any] = {}
def on_trace_start(self, trace: Trace) -> None:
"""Called when a trace is started.
Args:
trace: The trace that started.
"""
otel_span = self._tracer.start_span(
name=trace.name,
attributes={
OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.AGENT.value,
},
)
self._root_spans[trace.trace_id] = otel_span
def on_trace_end(self, trace: Trace) -> None:
"""Called when a trace is finished.
Args:
trace: The trace that finished.
"""
if root_span := self._root_spans.pop(trace.trace_id, None):
# Get the first input and last output for this specific trace
trace_first_input = self._first_input.pop(trace.trace_id, None)
trace_last_output = self._last_output.pop(trace.trace_id, None)
# Set input/output attributes on the root span
if trace_first_input is not None:
try:
root_span.set_attribute(
INPUT_VALUE, safe_json_dumps(trace_first_input)
)
root_span.set_attribute(INPUT_MIME_TYPE, JSON)
except Exception:
# Fallback to string if JSON serialization fails
root_span.set_attribute(INPUT_VALUE, str(trace_first_input))
if trace_last_output is not None:
try:
root_span.set_attribute(
OUTPUT_VALUE, safe_json_dumps(trace_last_output)
)
root_span.set_attribute(OUTPUT_MIME_TYPE, JSON)
except Exception:
# Fallback to string if JSON serialization fails
root_span.set_attribute(OUTPUT_VALUE, str(trace_last_output))
root_span.set_status(Status(StatusCode.OK))
root_span.end()
else:
# Clean up stored input/output for this trace if root span doesn't exist
self._first_input.pop(trace.trace_id, None)
self._last_output.pop(trace.trace_id, None)
def on_span_start(self, span: Span[Any]) -> None:
"""Called when a span is started.
Args:
span: The span that started.
"""
if not span.started_at:
return
start_time = datetime.fromisoformat(span.started_at)
parent_span = (
self._otel_spans.get(span.parent_id)
if span.parent_id
else self._root_spans.get(span.trace_id)
)
context = set_span_in_context(parent_span) if parent_span else None
span_name = _get_span_name(span)
otel_span = self._tracer.start_span(
name=span_name,
context=context,
start_time=_as_utc_nano(start_time),
attributes={
OPENINFERENCE_SPAN_KIND: _get_span_kind(span.span_data),
LLM_SYSTEM: OpenInferenceLLMSystemValues.OPENAI.value,
},
)
self._otel_spans[span.span_id] = otel_span
self._tokens[span.span_id] = attach(set_span_in_context(otel_span))
def on_span_end(self, span: Span[Any]) -> None:
"""Called when a span is finished. Should not block or raise exceptions.
Args:
span: The span that finished.
"""
if token := self._tokens.pop(span.span_id, None):
detach(token) # type: ignore[arg-type]
if not (otel_span := self._otel_spans.pop(span.span_id, None)):
return
otel_span.update_name(_get_span_name(span))
# flatten_attributes: dict[str, AttributeValue] = dict(_flatten(span.export()))
# otel_span.set_attributes(flatten_attributes)
data = span.span_data
if isinstance(data, GenerationSpanData):
for k, v in _get_attributes_from_generation_span_data(data):
otel_span.set_attribute(k, v)
elif isinstance(data, FunctionSpanData):
for k, v in _get_attributes_from_function_span_data(data):
otel_span.set_attribute(k, v)
elif isinstance(data, AgentSpanData):
otel_span.set_attribute(GRAPH_NODE_ID, data.name)
# Lookup the parent node if exists
key = f"{data.name}:{span.trace_id}"
if parent_node := self._reverse_handoffs_dict.pop(key, None):
otel_span.set_attribute(GRAPH_NODE_PARENT_ID, parent_node)
end_time: Optional[int] = None
if span.ended_at:
try:
end_time = _as_utc_nano(datetime.fromisoformat(span.ended_at))
except ValueError:
pass
otel_span.set_status(status=_get_span_status(span))
otel_span.end(end_time)
# Store first input and last output per trace_id
trace_id = span.trace_id
input_: Optional[Any] = None
output: Optional[Any] = None
if isinstance(data, FunctionSpanData):
input_ = data.input
output = data.output
elif isinstance(data, GenerationSpanData):
input_ = data.input
output = data.output
if trace_id not in self._first_input and input_ is not None:
self._first_input[trace_id] = input_
if output is not None:
self._last_output[trace_id] = output
def force_flush(self) -> None:
"""Forces an immediate flush of all queued spans/traces."""
# TODO
def shutdown(self) -> None:
"""Called when the application stops."""
# TODO
def _as_utc_nano(dt: datetime) -> int:
return int(dt.astimezone(timezone.utc).timestamp() * 1_000_000_000)
def _get_span_name(obj: Span[Any]) -> str:
if hasattr(data := obj.span_data, "name") and isinstance(name := data.name, str):
return name
return obj.span_data.type # type: ignore[no-any-return]
def _get_span_kind(obj: SpanData) -> str:
if isinstance(obj, AgentSpanData):
return OpenInferenceSpanKindValues.AGENT.value
if isinstance(obj, FunctionSpanData):
return OpenInferenceSpanKindValues.TOOL.value
if isinstance(obj, GenerationSpanData):
return OpenInferenceSpanKindValues.LLM.value
return OpenInferenceSpanKindValues.CHAIN.value
def _get_attributes_from_generation_span_data(
obj: GenerationSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
if isinstance(model := obj.model, str):
yield LLM_MODEL_NAME, model
if isinstance(obj.model_config, dict) and (
param := {k: v for k, v in obj.model_config.items() if v is not None}
):
yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(param)
if base_url := param.get("base_url"):
if "api.openai.com" in base_url:
yield LLM_PROVIDER, OpenInferenceLLMProviderValues.OPENAI.value
yield from _get_attributes_from_chat_completions_input(obj.input)
yield from _get_attributes_from_chat_completions_output(obj.output)
yield from _get_attributes_from_chat_completions_usage(obj.usage)
def _get_attributes_from_chat_completions_input(
obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
try:
yield INPUT_VALUE, safe_json_dumps(obj)
yield INPUT_MIME_TYPE, JSON
except Exception:
pass
yield from _get_attributes_from_chat_completions_message_dicts(
obj,
f"{LLM_INPUT_MESSAGES}.",
)
def _get_attributes_from_chat_completions_output(
obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
try:
yield OUTPUT_VALUE, safe_json_dumps(obj)
yield OUTPUT_MIME_TYPE, JSON
except Exception:
pass
yield from _get_attributes_from_chat_completions_message_dicts(
obj,
f"{LLM_OUTPUT_MESSAGES}.",
)
def _get_attributes_from_chat_completions_message_dicts(
obj: Iterable[Mapping[str, Any]],
prefix: str = "",
msg_idx: int = 0,
tool_call_idx: int = 0,
) -> Iterator[tuple[str, AttributeValue]]:
if not isinstance(obj, Iterable):
return
for msg in obj:
if isinstance(role := msg.get("role"), str):
yield f"{prefix}{msg_idx}.{MESSAGE_ROLE}", role
if content := msg.get("content"):
yield from _get_attributes_from_chat_completions_message_content(
content,
f"{prefix}{msg_idx}.",
)
if isinstance(tool_call_id := msg.get("tool_call_id"), str):
yield f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALL_ID}", tool_call_id
if isinstance(tool_calls := msg.get("tool_calls"), Iterable):
for tc in tool_calls:
yield from _get_attributes_from_chat_completions_tool_call_dict(
tc,
f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALLS}.{tool_call_idx}.",
)
tool_call_idx += 1
msg_idx += 1
def _get_attributes_from_chat_completions_message_content(
obj: Union[str, Iterable[Mapping[str, Any]]],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if isinstance(obj, str):
yield f"{prefix}{MESSAGE_CONTENT}", obj
elif isinstance(obj, Iterable):
for i, item in enumerate(obj):
if not isinstance(item, Mapping):
continue
yield from _get_attributes_from_chat_completions_message_content_item(
item,
f"{prefix}{MESSAGE_CONTENTS}.{i}.",
)
def _get_attributes_from_chat_completions_message_content_item(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if obj.get("type") == "text" and (text := obj.get("text")):
yield f"{prefix}{MESSAGE_CONTENT_TYPE}", "text"
yield f"{prefix}{MESSAGE_CONTENT_TEXT}", text
def _get_attributes_from_chat_completions_tool_call_dict(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if id_ := obj.get("id"):
yield f"{prefix}{TOOL_CALL_ID}", id_
if function := obj.get("function"):
if name := function.get("name"):
yield f"{prefix}{TOOL_CALL_FUNCTION_NAME}", name
if arguments := function.get("arguments"):
if arguments != "{}":
yield f"{prefix}{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments
def _get_attributes_from_chat_completions_usage(
obj: Optional[Mapping[str, Any]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
if input_tokens := obj.get("input_tokens"):
yield LLM_TOKEN_COUNT_PROMPT, input_tokens
if output_tokens := obj.get("output_tokens"):
yield LLM_TOKEN_COUNT_COMPLETION, output_tokens
# convert dict, tuple, etc into one of these types ['bool', 'str', 'bytes', 'int', 'float']
def _convert_to_primitive(value: Any) -> Union[bool, str, bytes, int, float]:
if isinstance(value, (bool, str, bytes, int, float)):
return value
if isinstance(value, (list, tuple)):
return safe_json_dumps(value)
if isinstance(value, dict):
return safe_json_dumps(value)
return str(value)
def _get_attributes_from_function_span_data(
obj: FunctionSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
yield TOOL_NAME, obj.name
if obj.input:
yield INPUT_VALUE, obj.input
yield INPUT_MIME_TYPE, JSON
if obj.output is not None:
yield OUTPUT_VALUE, _convert_to_primitive(obj.output)
if (
isinstance(obj.output, str)
and len(obj.output) > 1
and obj.output[0] == "{"
and obj.output[-1] == "}"
):
yield OUTPUT_MIME_TYPE, JSON
def _get_span_status(obj: Span[Any]) -> Status:
if error := getattr(obj, "error", None):
return Status(
status_code=StatusCode.ERROR,
description=f"{error.get('message')}: {error.get('data')}",
)
else:
return Status(StatusCode.OK)
def _flatten(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
for key, value in obj.items():
if isinstance(value, dict):
yield from _flatten(value, f"{prefix}{key}.")
elif isinstance(value, (str, int, float, bool)):
yield f"{prefix}{key}", value
else:
yield f"{prefix}{key}", str(value)
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = (
SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
)
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
)
LLM_TOOLS = SpanAttributes.LLM_TOOLS
METADATA = SpanAttributes.METADATA
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION
TOOL_NAME = SpanAttributes.TOOL_NAME
TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
GRAPH_NODE_ID = SpanAttributes.GRAPH_NODE_ID
GRAPH_NODE_PARENT_ID = SpanAttributes.GRAPH_NODE_PARENT_ID
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_CONTENTS = MessageAttributes.MESSAGE_CONTENTS
MESSAGE_CONTENT_IMAGE = MessageContentAttributes.MESSAGE_CONTENT_IMAGE
MESSAGE_CONTENT_TEXT = MessageContentAttributes.MESSAGE_CONTENT_TEXT
MESSAGE_CONTENT_TYPE = MessageContentAttributes.MESSAGE_CONTENT_TYPE
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = (
MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
)
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
JSON = OpenInferenceMimeTypeValues.JSON.value
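
Below is a minimal, self-contained sketch of the dotted-key flattening that _flatten performs. It is a standalone re-implementation for illustration (hypothetical name flatten_attrs, plain Any values instead of AttributeValue), not the instrumented module itself.

from typing import Any, Iterator, Mapping

def flatten_attrs(obj: Mapping[str, Any], prefix: str = "") -> Iterator[tuple[str, Any]]:
    # Nested mappings become dotted keys; primitives pass through unchanged;
    # anything else is stringified, mirroring the _flatten helper above.
    for key, value in obj.items():
        if isinstance(value, dict):
            yield from flatten_attrs(value, f"{prefix}{key}.")
        elif isinstance(value, (str, int, float, bool)):
            yield f"{prefix}{key}", value
        else:
            yield f"{prefix}{key}", str(value)

print(dict(flatten_attrs({"gen_ai": {"usage": {"input_tokens": 12}}, "tags": ["a", "b"]})))
# prints {'gen_ai.usage.input_tokens': 12, 'tags': "['a', 'b']"}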

View File

@@ -50,7 +50,7 @@ nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==2.6.1
openpyxl==3.0.10
openpyxl==3.1.5
passlib==1.7.4
playwright==1.55.0
psutil==5.9.5
@@ -84,7 +84,7 @@ supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
types-openpyxl==3.0.4.7
types-openpyxl==3.1.5.20250919
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.35.0
@@ -108,7 +108,7 @@ exa_py==1.15.4
braintrust[openai-agents]==0.2.6
braintrust-langchain==0.0.4
openai-agents==0.4.2
langfuse==3.10.0
langfuse==3.7.0
nest_asyncio==1.6.0
openinference-instrumentation-openai-agents==1.3.0
opentelemetry-proto==1.38.0

View File

@@ -19,14 +19,3 @@ class RerankerProvider(str, Enum):
class EmbedTextType(str, Enum):
QUERY = "query"
PASSAGE = "passage"
class WebSearchProviderType(str, Enum):
GOOGLE_PSE = "google_pse"
SERPER = "serper"
EXA = "exa"
class WebContentProviderType(str, Enum):
ONYX_WEB_CRAWLER = "onyx_web_crawler"
FIRECRAWL = "firecrawl"

View File

@@ -154,7 +154,7 @@ def test_versions_endpoint(client: TestClient) -> None:
assert migration["onyx"] == "airgapped-intfloat-nomic-migration"
assert migration["relational_db"] == "postgres:15.2-alpine"
assert migration["index"] == "vespaengine/vespa:8.277.17"
assert migration["nginx"] == "nginx:1.25.5-alpine"
assert migration["nginx"] == "nginx:1.23.4-alpine"
# Verify versions are different between stable and dev
assert stable["onyx"] != dev["onyx"], "Stable and dev versions should be different"

View File

@@ -2,8 +2,6 @@ import os
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
@@ -520,52 +518,3 @@ class TestHubSpotConnector:
note_url = connector._get_object_url("notes", "44444")
expected_note_url = "https://app.hubspot.com/contacts/12345/objects/0-4/44444"
assert note_url == expected_note_url
def test_ticket_with_none_content(self) -> None:
"""Test that tickets with None content are handled gracefully."""
connector = HubSpotConnector(object_types=["tickets"], batch_size=10)
connector._access_token = "mock_token"
connector._portal_id = "mock_portal_id"
# Create a mock ticket with None content
mock_ticket = MagicMock()
mock_ticket.id = "12345"
mock_ticket.properties = {
"subject": "Test Ticket",
"content": None, # This is the key test case
"hs_ticket_priority": "HIGH",
}
mock_ticket.updated_at = datetime.now(timezone.utc)
# Mock the HubSpot API client
mock_api_client = MagicMock()
# Mock the API calls and associated object methods
with patch(
"onyx.connectors.hubspot.connector.HubSpot"
) as MockHubSpot, patch.object(
connector, "_paginated_results"
) as mock_paginated, patch.object(
connector, "_get_associated_objects", return_value=[]
), patch.object(
connector, "_get_associated_notes", return_value=[]
):
MockHubSpot.return_value = mock_api_client
mock_paginated.return_value = iter([mock_ticket])
# This should not raise a validation error
document_batches = connector._process_tickets()
first_batch = next(document_batches, None)
# Verify the document was created successfully
assert first_batch is not None
assert len(first_batch) == 1
doc = first_batch[0]
assert doc.id == "hubspot_ticket_12345"
assert doc.semantic_identifier == "Test Ticket"
# Verify the first section has an empty string, not None
assert len(doc.sections) > 0
assert doc.sections[0].text == "" # Should be empty string, not None
assert doc.sections[0].link is not None

View File

@@ -1,6 +1,4 @@
import os
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4
@@ -10,11 +8,9 @@ os.environ["MODEL_SERVER_HOST"] = "disabled"
os.environ["MODEL_SERVER_PORT"] = "9000"
from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
from onyx.configs.constants import FederatedConnectorSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
from onyx.db.models import LLMProvider
@@ -577,174 +573,3 @@ class TestSlackBotFederatedSearch:
finally:
self._teardown_common_mocks(patches)
@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_missing_scope_resilience(
mock_web_client: Mock, mock_redis_client: Mock
) -> None:
"""Test that missing scopes are handled gracefully"""
# Setup mock Redis client
mock_redis = MagicMock()
mock_redis.get.return_value = None # Cache miss
mock_redis_client.return_value = mock_redis
# Setup mock Slack client that simulates missing_scope error
mock_client_instance = MagicMock()
mock_web_client.return_value = mock_client_instance
# Track which channel types were attempted
attempted_types: list[str] = []
def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
if types:
attempted_types.append(types)
# First call: all types including mpim -> missing_scope error
if types and "mpim" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "mpim:read",
"provided": "identify,channels:history,channels:read,groups:read,im:read,search:read",
}
raise SlackApiError("missing_scope", error_response)
# Second call: without mpim -> success
mock_response = MagicMock()
mock_response.validate.return_value = None
mock_response.data = {
"channels": [
{
"id": "C1234567890",
"name": "general",
"is_channel": True,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": False,
"is_member": True,
},
{
"id": "D9876543210",
"name": "",
"is_channel": False,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": True,
"is_member": True,
},
],
"response_metadata": {},
}
return mock_response
mock_client_instance.conversations_list.side_effect = mock_conversations_list
# Call the function
result = fetch_and_cache_channel_metadata(
access_token="xoxp-test-token",
team_id="T1234567890",
include_private=True,
)
# Assertions
# Should have attempted twice: once with mpim, once without
assert len(attempted_types) == 2, f"Expected 2 attempts, got {len(attempted_types)}"
assert "mpim" in attempted_types[0], "First attempt should include mpim"
assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"
# Should have successfully returned channels despite missing scope
assert len(result) == 2, f"Expected 2 channels, got {len(result)}"
assert "C1234567890" in result, "Should have public channel"
assert "D9876543210" in result, "Should have DM channel"
# Verify channel metadata structure
assert result["C1234567890"]["name"] == "general"
assert result["C1234567890"]["type"] == "public_channel"
assert result["D9876543210"]["type"] == "im"
@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_multiple_missing_scopes_resilience(
mock_web_client: Mock, mock_redis_client: Mock
) -> None:
"""Test handling multiple missing scopes gracefully"""
# Setup mock Redis client
mock_redis = MagicMock()
mock_redis.get.return_value = None # Cache miss
mock_redis_client.return_value = mock_redis
# Setup mock Slack client
mock_client_instance = MagicMock()
mock_web_client.return_value = mock_client_instance
# Track attempts
attempted_types: list[str] = []
def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
if types:
attempted_types.append(types)
# First: mpim missing
if types and "mpim" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "mpim:read",
"provided": "identify,channels:history,channels:read,groups:read",
}
raise SlackApiError("missing_scope", error_response)
# Second: im missing
if types and "im" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "im:read",
"provided": "identify,channels:history,channels:read,groups:read",
}
raise SlackApiError("missing_scope", error_response)
# Third: success with only public and private channels
mock_response = MagicMock()
mock_response.validate.return_value = None
mock_response.data = {
"channels": [
{
"id": "C1234567890",
"name": "general",
"is_channel": True,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": False,
"is_member": True,
}
],
"response_metadata": {},
}
return mock_response
mock_client_instance.conversations_list.side_effect = mock_conversations_list
# Call the function
result = fetch_and_cache_channel_metadata(
access_token="xoxp-test-token",
team_id="T1234567890",
include_private=True,
)
# Should gracefully handle multiple missing scopes
assert len(attempted_types) == 3, f"Expected 3 attempts, got {len(attempted_types)}"
assert "mpim" in attempted_types[0], "First attempt should include mpim"
assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"
assert "im" in attempted_types[1], "Second attempt should include im"
assert "im" not in attempted_types[2], "Third attempt should not include im"
# Should still return available channels
assert len(result) == 1, f"Expected 1 channel, got {len(result)}"
assert result["C1234567890"]["name"] == "general"

View File

@@ -140,22 +140,6 @@ class CCPairManager:
)
result.raise_for_status()
@staticmethod
def unpause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "ACTIVE"},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def delete(
cc_pair: DATestCCPair,

View File

@@ -382,7 +382,6 @@ def test_mock_connector_checkpoint_recovery(
assert response.status_code == 200
# Create CC Pair and run initial indexing attempt
# Note: Setting refresh_freq to allow manual retrigger after failure
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-checkpoint-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
@@ -392,7 +391,6 @@ def test_mock_connector_checkpoint_recovery(
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
refresh_freq=60 * 60, # 1 hour
)
# Wait for first index attempt to complete
@@ -465,14 +463,10 @@ def test_mock_connector_checkpoint_recovery(
)
assert response.status_code == 200
# After the failure, the connector is in repeated error state and paused.
# Set the manual indexing trigger first (while paused), then unpause.
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
# prevent the connector from being re-paused when repeated error state is detected.
# Trigger another indexing attempt
CCPairManager.run_once(
cc_pair, from_beginning=False, user_performing_action=admin_user
)
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
index_attempts_to_ignore=[initial_index_attempt.id],

View File

@@ -11,7 +11,6 @@ from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import InputType
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
@@ -122,11 +121,6 @@ def test_repeated_error_state_detection_and_recovery(
)
assert cc_pair_obj is not None
if cc_pair_obj.in_repeated_error_state:
# Verify the connector is also paused to prevent further indexing attempts
assert cc_pair_obj.status == ConnectorCredentialPairStatus.PAUSED, (
f"Expected status to be PAUSED when in repeated error state, "
f"but got {cc_pair_obj.status}"
)
break
if time.monotonic() - start_time > 30:
@@ -151,13 +145,10 @@ def test_repeated_error_state_detection_and_recovery(
)
assert response.status_code == 200
# Set the manual indexing trigger first (while paused), then unpause.
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
# prevent the connector from being re-paused when repeated error state is detected.
# Run another indexing attempt that should succeed
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,

View File

@@ -140,84 +140,7 @@ def tool_call_chunk(
)
class FakeErrorTool(Tool):
"""Base fake tool for testing."""
def __init__(self, tool_name: str, tool_id: int = 1):
self._tool_name = tool_name
self._tool_id = tool_id
self.calls: list[dict[str, Any]] = []
@property
def id(self) -> int:
return self._tool_id
@property
def name(self) -> str:
return self._tool_name
@property
def description(self) -> str:
return f"{self._tool_name} tool"
@property
def display_name(self) -> str:
return self._tool_name.replace("_", " ").title()
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self._tool_name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"queries": {
"type": "array",
"items": {"type": "string"},
}
},
"required": ["queries"],
},
},
}
def run_v2(
self,
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
raise Exception("Error running tool")
def build_tool_message_content(self, *args: Any) -> str:
return ""
def get_args_for_non_tool_calling_llm(
self, query: Any, history: Any, llm: Any, force_run: bool = False
) -> None:
return None
def run(
self, override_kwargs: Any = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
raise NotImplementedError
yield # Make this a generator
def final_result(self, *args: Any) -> dict[str, Any]:
return {}
def build_next_prompt(
self,
prompt_builder: Any,
tool_call_summary: Any,
tool_responses: Any,
using_tool_calling_llm: Any,
) -> Any:
return prompt_builder
# Fake tools for testing
class FakeTool(Tool):
"""Base fake tool for testing."""
@@ -269,9 +192,7 @@ class FakeTool(Tool):
) -> Any:
queries = kwargs.get("queries", [])
self.calls.append({"queries": queries})
context = run_context.context
flag_name = f"{self._tool_name}_called"
context[flag_name] = True
run_context.context[f"{self._tool_name}_called"] = True
return f"{self.display_name} results for: {', '.join(queries)}"
def build_tool_message_content(self, *args: Any) -> str:
@@ -311,9 +232,3 @@ def fake_internal_search_tool() -> FakeTool:
def fake_web_search_tool() -> FakeTool:
"""Fixture providing a fake web search tool."""
return FakeTool("web_search", tool_id=2)
@pytest.fixture
def fake_error_tool() -> FakeErrorTool:
"""Fixture providing a fake error tool."""
return FakeErrorTool("error_tool", tool_id=3)

View File

@@ -7,7 +7,6 @@ from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.agents.agent_framework.query import query
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.model_response import ModelResponseStream
from tests.unit.onyx.agents.agent_framework.conftest import FakeErrorTool
from tests.unit.onyx.agents.agent_framework.conftest import FakeTool
from tests.unit.onyx.agents.agent_framework.conftest import stream_chunk
from tests.unit.onyx.agents.agent_framework.conftest import tool_call_chunk
@@ -464,165 +463,3 @@ def test_query_emits_reasoning_done_before_message_start(
assert len(result.new_messages_stateful) == 1
assert result.new_messages_stateful[0]["role"] == "assistant"
assert result.new_messages_stateful[0]["content"] == "Based on my analysis"
def test_query_treats_tool_like_message_as_tool_call(
fake_llm: Callable[[list[ModelResponseStream]], Any],
fake_internal_search_tool: FakeTool,
) -> None:
"""Test that query treats JSON message content as a tool call when structured_response_format is empty."""
stream_id = "chatcmpl-tool-like-message"
responses = [
stream_chunk(
id=stream_id,
created="1762545000",
content='{"name": "internal_search", ',
),
stream_chunk(
id=stream_id,
created="1762545000",
content='"arguments": {"queries": ["new agent", "framework"]}',
),
stream_chunk(id=stream_id, created="1762545000", content="}"),
stream_chunk(id=stream_id, created="1762545000", finish_reason="stop"),
]
llm = fake_llm(responses)
context: dict[str, bool] = {}
messages: list[ChatCompletionMessage] = [
{"role": "user", "content": "tell me about agents"}
]
result = query(
llm,
messages,
tools=[fake_internal_search_tool],
context=context,
tool_choice=None,
structured_response_format={},
)
events = list(result.stream)
run_item_events = [e for e in events if isinstance(e, RunItemStreamEvent)]
tool_call_events = [e for e in run_item_events if e.type == "tool_call"]
assert len(tool_call_events) == 1
assert tool_call_events[0].details is not None
assert isinstance(tool_call_events[0].details, ToolCallStreamItem)
assert tool_call_events[0].details.name == "internal_search"
assert (
tool_call_events[0].details.arguments
== '{"queries": ["new agent", "framework"]}'
)
assert len(fake_internal_search_tool.calls) == 1
assert fake_internal_search_tool.calls[0]["queries"] == ["new agent", "framework"]
assert context["internal_search_called"] is True
tool_output_events = [e for e in run_item_events if e.type == "tool_call_output"]
assert len(tool_output_events) == 1
assert tool_output_events[0].details is not None
assert isinstance(tool_output_events[0].details, ToolCallOutputStreamItem)
assert (
tool_output_events[0].details.output
== "Internal Search results for: new agent, framework"
)
assert len(result.new_messages_stateful) == 2
assert result.new_messages_stateful[0]["role"] == "assistant"
assert result.new_messages_stateful[0]["content"] is None
assert len(result.new_messages_stateful[0]["tool_calls"]) == 1
assert (
result.new_messages_stateful[0]["tool_calls"][0]["function"]["name"]
== "internal_search"
)
assert result.new_messages_stateful[1]["role"] == "tool"
assert (
result.new_messages_stateful[1]["tool_call_id"]
== tool_call_events[0].details.call_id
)
assert result.new_messages_stateful[1]["content"] == (
"Internal Search results for: new agent, framework"
)
def test_query_handles_tool_error_gracefully(
fake_llm: Callable[[list[ModelResponseStream]], Any],
fake_error_tool: FakeErrorTool,
) -> None:
"""Test that query handles tool errors gracefully and treats the error as the tool call response."""
call_id = "toolu_01ErrorToolCall123"
stream_id = "chatcmpl-error-test-12345"
responses = [
stream_chunk(
id=stream_id,
created="1762545000",
content="",
tool_calls=[tool_call_chunk(id=call_id, name="error_tool", arguments="")],
),
stream_chunk(
id=stream_id,
created="1762545000",
content="",
tool_calls=[tool_call_chunk(arguments="")],
),
*[
stream_chunk(
id=stream_id,
created="1762545000",
content="",
tool_calls=[tool_call_chunk(arguments=arg)],
)
for arg in ['{"queries": ', '["test"]}']
],
stream_chunk(id=stream_id, created="1762545000", finish_reason="tool_calls"),
]
llm = fake_llm(responses)
context: dict[str, bool] = {}
messages: list[ChatCompletionMessage] = [
{"role": "user", "content": "call the error tool"}
]
result = query(
llm,
messages,
tools=[fake_error_tool],
context=context,
tool_choice=None,
)
events = list(result.stream)
run_item_events = [e for e in events if isinstance(e, RunItemStreamEvent)]
# Verify tool_call event was emitted
tool_call_events = [e for e in run_item_events if e.type == "tool_call"]
assert len(tool_call_events) == 1
assert tool_call_events[0].details is not None
assert isinstance(tool_call_events[0].details, ToolCallStreamItem)
assert tool_call_events[0].details.call_id == call_id
assert tool_call_events[0].details.name == "error_tool"
# Verify tool_call_output event was emitted with the error message
tool_output_events = [e for e in run_item_events if e.type == "tool_call_output"]
assert len(tool_output_events) == 1
assert tool_output_events[0].details is not None
assert isinstance(tool_output_events[0].details, ToolCallOutputStreamItem)
assert tool_output_events[0].details.call_id == call_id
assert tool_output_events[0].details.output == "Error: Error running tool"
# Verify new_messages contains expected assistant and tool messages
assert len(result.new_messages_stateful) == 2
assert result.new_messages_stateful[0]["role"] == "assistant"
assert result.new_messages_stateful[0]["content"] is None
assert len(result.new_messages_stateful[0]["tool_calls"]) == 1
assert result.new_messages_stateful[0]["tool_calls"][0]["id"] == call_id
assert (
result.new_messages_stateful[0]["tool_calls"][0]["function"]["name"]
== "error_tool"
)
assert result.new_messages_stateful[1]["role"] == "tool"
assert result.new_messages_stateful[1]["tool_call_id"] == call_id
assert result.new_messages_stateful[1]["content"] == "Error: Error running tool"
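
The removed test above expects a failing tool to be surfaced as the tool-call output ("Error: <message>") rather than aborting the stream. A minimal sketch of that error-to-output pattern (hypothetical helper, not the actual query implementation):

from typing import Any
from collections.abc import Callable

def run_tool_safely(tool_fn: Callable[..., Any], **kwargs: Any) -> str:
    # A tool failure becomes a string result so the agent loop can continue and
    # the error text is recorded as the tool-call response.
    try:
        return str(tool_fn(**kwargs))
    except Exception as e:
        return f"Error: {e}"

def failing_tool(**kwargs: Any) -> Any:
    raise Exception("Error running tool")

print(run_tool_safely(failing_tool))  # prints: Error: Error running tool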

View File

@@ -2,9 +2,6 @@ from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage as LangChainSystemMessage
from onyx.agents.agent_framework.message_format import (
base_messages_to_chat_completion_msgs,
)
from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
from onyx.agents.agent_sdk.message_types import InputTextContent
@@ -215,46 +212,3 @@ def test_assistant_message_list_content_non_responses_api() -> None:
assert second_content["type"] == "input_text"
text_content2: InputTextContent = second_content # type: ignore[assignment]
assert text_content2["text"] == "It's known for the Eiffel Tower."
def test_base_messages_to_chat_completion_msgs_basic() -> None:
"""Ensure system and user messages convert to chat completion format."""
system_message = LangChainSystemMessage(
content="You are a helpful assistant.",
additional_kwargs={},
response_metadata={},
)
human_message = HumanMessage(
content="hello",
additional_kwargs={},
response_metadata={},
)
results = base_messages_to_chat_completion_msgs([system_message, human_message])
assert results == [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "hello"},
]
def test_base_messages_to_chat_completion_msgs_with_tool_call() -> None:
"""Ensure assistant messages with tool calls are preserved."""
ai_message = AIMessage(
content="",
tool_calls=[
{
"id": "call_1",
"name": "internal_search",
"args": {"query": "test"},
}
],
additional_kwargs={},
response_metadata={},
)
results = base_messages_to_chat_completion_msgs([ai_message])
assert len(results) == 1
assert results[0]["role"] == "assistant"
assert results[0]["tool_calls"][0]["function"]["name"] == "internal_search"

View File

@@ -1,17 +0,0 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_content_provider_from_config,
)
from shared_configs.enums import WebContentProviderType
def test_build_content_provider_returns_onyx_crawler() -> None:
provider = build_content_provider_from_config(
provider_type=WebContentProviderType.ONYX_WEB_CRAWLER,
api_key=None,
config={"timeout_seconds": "20"},
provider_name="Built-in",
)
assert isinstance(provider, OnyxWebCrawlerClient)

View File

@@ -1,228 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import pytest
from onyx.auth import users as users_module
def test_extract_email_requires_valid_format() -> None:
"""Helper should validate email format before returning value."""
assert users_module._extract_email_from_jwt({"email": "invalid@"}) is None
result = users_module._extract_email_from_jwt(
{"preferred_username": "ValidUser@Example.COM"}
)
assert result == "validuser@example.com"
@pytest.mark.asyncio
async def test_get_or_create_user_updates_expiry(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Existing web-login users should be returned and their expiry synced."""
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
invited_checked: dict[str, str] = {}
def mark_invited(value: str) -> None:
invited_checked["email"] = value
domain_checked: dict[str, str] = {}
def mark_domain(value: str) -> None:
domain_checked["email"] = value
monkeypatch.setattr(users_module, "verify_email_is_invited", mark_invited)
monkeypatch.setattr(users_module, "verify_email_domain", mark_domain)
email = "jwt-user@example.com"
exp_value = 1_700_000_000
payload: dict[str, Any] = {"email": email, "exp": exp_value}
existing_user = MagicMock()
existing_user.email = email
existing_user.oidc_expiry = None
existing_user.role.is_web_login.return_value = True # type: ignore[attr-defined]
manager_holder: dict[str, Any] = {}
class StubUserManager:
def __init__(self, _user_db: object) -> None:
manager_holder["instance"] = self # type: ignore[assignment]
self.user_db = MagicMock()
self.user_db.update = AsyncMock()
async def get_by_email(self, email_arg: str) -> MagicMock:
assert email_arg == email
return existing_user
monkeypatch.setattr(users_module, "UserManager", StubUserManager)
monkeypatch.setattr(
users_module,
"SQLAlchemyUserAdminDB",
lambda *args, **kwargs: MagicMock(),
)
result = await users_module._get_or_create_user_from_jwt(
payload, MagicMock(), MagicMock()
)
assert result is existing_user
assert invited_checked["email"] == email
assert domain_checked["email"] == email
expected_expiry = datetime.fromtimestamp(exp_value, tz=timezone.utc)
instance = manager_holder["instance"]
instance.user_db.update.assert_awaited_once_with( # type: ignore[attr-defined]
existing_user, {"oidc_expiry": expected_expiry}
)
assert existing_user.oidc_expiry == expected_expiry
@pytest.mark.asyncio
async def test_get_or_create_user_skips_inactive(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Inactive users should not be re-authenticated via JWT."""
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
email = "inactive@example.com"
payload: dict[str, Any] = {"email": email}
existing_user = MagicMock()
existing_user.email = email
existing_user.is_active = False
existing_user.role.is_web_login.return_value = True # type: ignore[attr-defined]
class StubUserManager:
def __init__(self, _user_db: object) -> None:
self.user_db = MagicMock()
self.user_db.update = AsyncMock()
async def get_by_email(self, email_arg: str) -> MagicMock:
assert email_arg == email
return existing_user
monkeypatch.setattr(users_module, "UserManager", StubUserManager)
monkeypatch.setattr(
users_module,
"SQLAlchemyUserAdminDB",
lambda *args, **kwargs: MagicMock(),
)
result = await users_module._get_or_create_user_from_jwt(
payload, MagicMock(), MagicMock()
)
assert result is None
@pytest.mark.asyncio
async def test_get_or_create_user_handles_race_conditions(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""If provisioning races, newly inactive users should still be blocked."""
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
email = "race@example.com"
payload: dict[str, Any] = {"email": email}
inactive_user = MagicMock()
inactive_user.email = email
inactive_user.is_active = False
inactive_user.role.is_web_login.return_value = True # type: ignore[attr-defined]
class StubUserManager:
def __init__(self, _user_db: object) -> None:
self.user_db = MagicMock()
self.user_db.update = AsyncMock()
self.get_calls = 0
async def get_by_email(self, email_arg: str) -> MagicMock:
assert email_arg == email
if self.get_calls == 0:
self.get_calls += 1
raise users_module.exceptions.UserNotExists()
self.get_calls += 1
return inactive_user
async def create(self, *args, **kwargs): # type: ignore[no-untyped-def]
raise users_module.exceptions.UserAlreadyExists()
monkeypatch.setattr(users_module, "UserManager", StubUserManager)
monkeypatch.setattr(
users_module,
"SQLAlchemyUserAdminDB",
lambda *args, **kwargs: MagicMock(),
)
result = await users_module._get_or_create_user_from_jwt(
payload, MagicMock(), MagicMock()
)
assert result is None
@pytest.mark.asyncio
async def test_get_or_create_user_provisions_new_user(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A brand new JWT user should be provisioned automatically."""
email = "new-user@example.com"
payload = {"email": email}
created_user = MagicMock()
created_user.email = email
created_user.oidc_expiry = None
created_user.role.is_web_login.return_value = True # type: ignore[attr-defined]
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", False)
monkeypatch.setattr(users_module, "generate_password", lambda: "TempPass123!")
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
recorded: dict[str, Any] = {}
class StubUserManager:
def __init__(self, _user_db: object) -> None:
recorded["instance"] = self
self.user_db = MagicMock()
self.user_db.update = AsyncMock()
async def get_by_email(self, _email: str) -> MagicMock:
raise users_module.exceptions.UserNotExists()
async def create(self, user_create, safe=False, request=None): # type: ignore[no-untyped-def]
recorded["user_create"] = user_create
recorded["request"] = request
return created_user
monkeypatch.setattr(users_module, "UserManager", StubUserManager)
monkeypatch.setattr(
users_module,
"SQLAlchemyUserAdminDB",
lambda *args, **kwargs: MagicMock(),
)
request = MagicMock()
result = await users_module._get_or_create_user_from_jwt(
payload, request, MagicMock()
)
assert result is created_user
created_payload = recorded["user_create"]
assert created_payload.email == email
assert created_payload.is_verified is True
assert recorded["request"] is request
@pytest.mark.asyncio
async def test_get_or_create_user_requires_email_claim() -> None:
"""Tokens without a usable email claim should be ignored."""
result = await users_module._get_or_create_user_from_jwt(
{}, MagicMock(), MagicMock()
)
assert result is None

View File

@@ -6,7 +6,6 @@ from typing import Any
import pytest
from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
@@ -79,14 +78,6 @@ class FakeDummyTool(CustomTool):
),
)
def run_v2(
self,
run_context: RunContextWrapper[Any],
*args: Any,
**kwargs: Any,
) -> Any:
return "Tool executed successfully"
@pytest.fixture
def fake_dummy_tool() -> FakeDummyTool:
@@ -99,8 +90,8 @@ __all__ = [
"chat_turn_dependencies",
"fake_db_session",
"fake_dummy_tool",
"fake_model",
"fake_llm",
"fake_model",
"fake_redis_client",
"fake_tools",
]

View File

@@ -298,5 +298,4 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
stream_options={"include_usage": True},
)

View File

@@ -37,7 +37,7 @@ def test_partial_match_in_model_map() -> None:
"supports_audio_output": False,
"supports_function_calling": True,
"supports_response_schema": True,
"supports_system_messages": False,
"supports_system_messages": True,
"supports_tool_choice": True,
"supports_vision": True,
}
@@ -46,13 +46,13 @@ def test_partial_match_in_model_map() -> None:
assert result1 is not None
for key, value in _EXPECTED_FIELDS.items():
assert key in result1
assert result1[key] == value, "Unexpected value for key: {}".format(key)
assert result1[key] == value
result2 = find_model_obj(model_map, "openai", "gemma-3-27b-it")
assert result2 is not None
for key, value in _EXPECTED_FIELDS.items():
assert key in result2
assert result2[key] == value, "Unexpected value for key: {}".format(key)
assert result2[key] == value
get_model_map.cache_clear()

View File

@@ -1,48 +0,0 @@
"""Unit tests for tracing setup functions."""
import importlib
import os
from onyx.configs import app_configs
def test_setup_langfuse_if_creds_available_with_creds() -> None:
"""Test that setup_langfuse_if_creds_available executes without error when credentials are available."""
# Set credentials to non-empty values to avoid early return
os.environ["LANGFUSE_SECRET_KEY"] = "test-secret-key"
os.environ["LANGFUSE_PUBLIC_KEY"] = "test-public-key"
# Reload modules to pick up new environment variables
importlib.reload(app_configs)
from onyx.tracing import langfuse_tracing
importlib.reload(langfuse_tracing)
# Call the function - should not raise an error
langfuse_tracing.setup_langfuse_if_creds_available()
# Clean up
os.environ.pop("LANGFUSE_SECRET_KEY", None)
os.environ.pop("LANGFUSE_PUBLIC_KEY", None)
importlib.reload(app_configs)
def test_setup_braintrust_if_creds_available_with_creds() -> None:
"""Test that setup_braintrust_if_creds_available executes without error when credentials are available."""
# Set credentials to non-empty values to avoid early return
os.environ["BRAINTRUST_API_KEY"] = "test-api-key"
os.environ["BRAINTRUST_PROJECT"] = "test-project"
# Reload modules to pick up new environment variables
importlib.reload(app_configs)
from onyx.tracing import braintrust_tracing
importlib.reload(braintrust_tracing)
# Call the function - should not raise an error
braintrust_tracing.setup_braintrust_if_creds_available()
# Clean up environment variables
os.environ.pop("BRAINTRUST_API_KEY", None)
os.environ.pop("BRAINTRUST_PROJECT", None)
importlib.reload(app_configs)

View File

@@ -109,7 +109,7 @@ Resources:
Family: !Sub ${Environment}-${ServiceName}-TaskDefinition
ContainerDefinitions:
- Name: nginx
Image: nginx:1.25.5-alpine
Image: nginx:1.23.4-alpine
Cpu: 0
PortMappings:
- Name: nginx-80-tcp

View File

@@ -55,6 +55,8 @@ services:
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- BING_API_KEY=${BING_API_KEY:-}
- EXA_API_KEY=${EXA_API_KEY:-}
- DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
@@ -168,6 +170,8 @@ services:
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- BING_API_KEY=${BING_API_KEY:-}
- EXA_API_KEY=${EXA_API_KEY:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
@@ -400,7 +404,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up

View File

@@ -188,7 +188,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up

View File

@@ -212,7 +212,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up

View File

@@ -224,7 +224,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up

View File

@@ -193,7 +193,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up

View File

@@ -3,7 +3,7 @@
# =============================================================================
# This is the default configuration for Onyx. This file is fairly configurable,
# also see env.template for possible settings.
#
#
# PRODUCTION DEPLOYMENT CHECKLIST:
# To convert this setup to a production deployment following best practices,
# follow the checklist below. Note that there are other ways to secure the Onyx
@@ -283,7 +283,7 @@ services:
max-file: "6"
nginx:
image: nginx:1.25.5-alpine
image: nginx:1.23.4-alpine
restart: unless-stopped
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up
@@ -300,7 +300,7 @@ services:
- NGINX_PROXY_SEND_TIMEOUT=${NGINX_PROXY_SEND_TIMEOUT:-300}
- NGINX_PROXY_READ_TIMEOUT=${NGINX_PROXY_READ_TIMEOUT:-300}
ports:
- "${HOST_PORT_80:-80}:80"
- "80:80"
- "${HOST_PORT:-3000}:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
@@ -319,7 +319,7 @@ services:
# in order to make this work on both Unix-like systems and windows
# PRODUCTION: Change to app.conf.template.prod for production nginx config
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
cache:

Some files were not shown because too many files have changed in this diff.