Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits (6 commits)
sidebar-mo...header
| Author | SHA1 | Date |
|---|---|---|
|  | 9e9c3ec0b9 |  |
|  | 1457ca2a20 |  |
|  | edc390edc6 |  |
|  | 022624cb5a |  |
|  | f301257130 |  |
|  | 9eecc71cda |  |
.github/actions/setup-playwright/action.yml (vendored): 4 changes

@@ -7,9 +7,9 @@ runs:
      uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
      with:
        path: ~/.cache/ms-playwright
        key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
        key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
        restore-keys: |
          ${{ runner.os }}-${{ runner.arch }}-playwright-
          ${{ runner.os }}-playwright-

    - name: Install playwright
      shell: bash
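The cache key pair above follows the usual lockfile-hash convention: an exact key that is invalidated whenever the requirements file changes, plus a broader restore-keys prefix as a fallback, with the change here dropping the ${{ runner.arch }} qualifier. A minimal sketch of the same pattern using the stock actions/cache action (the repo itself pins runs-on/cache; paths and key parts here are illustrative):

    - name: Cache Playwright browsers
      uses: actions/cache@v4              # stock cache action, shown for illustration
      with:
        path: ~/.cache/ms-playwright      # where Playwright stores downloaded browsers
        # Exact key: misses (and re-caches) whenever the requirements file changes
        key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
        # Prefix fallback: restores the newest partial match so the install step
        # only downloads what changed
        restore-keys: |
          ${{ runner.os }}-playwright-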
.github/workflows/check-lazy-imports.yml (vendored): 5 changes

@@ -10,9 +10,6 @@ on:
      - main
      - 'release/**'

permissions:
  contents: read

jobs:
  check-lazy-imports:
    runs-on: ubuntu-latest
@@ -20,8 +17,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
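Two least-privilege idioms recur throughout these hunks: a read-only GITHUB_TOKEN grant and a checkout that does not persist credentials into the workspace. A minimal standalone sketch (workflow name and trigger are illustrative, not from this repo):

    name: example-check        # illustrative
    on:
      push:
        branches: [main]

    permissions:
      contents: read           # read-only token for every job below

    jobs:
      check:
        runs-on: ubuntu-latest
        steps:
          - uses: actions/checkout@v4
            with:
              persist-credentials: false   # don't leave the token in .git/config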
.github/workflows/deployment.yml (vendored): 772 changes
@@ -6,9 +6,6 @@ on:
      - "*"
  workflow_dispatch:

permissions:
  contents: read

env:
  IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
  EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -33,7 +30,7 @@ jobs:
      - name: Check which components to build and version info
        id: check
        run: |
          TAG="${GITHUB_REF_NAME}"
          TAG="${{ github.ref_name }}"
          # Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
          SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
          IS_CLOUD=false
@@ -82,143 +79,22 @@ jobs:
          echo "sanitized-tag=$SANITIZED_TAG"
        } >> "$GITHUB_OUTPUT"
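The TAG assignment pair above reflects a hardening pattern that recurs across this diff: read runtime values from environment variables instead of splicing ${{ ... }} expressions directly into shell text, so that a crafted ref name cannot inject shell syntax. A sketch of the two forms (the TAG_NAME mapping is illustrative; GITHUB_REF_NAME is a default runner variable):

    - name: Safe vs unsafe interpolation
      env:
        TAG_NAME: ${{ github.ref_name }}   # expanded by the runner, passed to the shell as data
      run: |
        TAG="${GITHUB_REF_NAME}"           # default env var: never parsed as script
        TAG="${TAG_NAME}"                  # explicit env mapping: equally safe
        # TAG="${{ github.ref_name }}"     # unsafe: a ref like v1";rm -rf $HOME;" becomes shell code
        echo "building $TAG"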
  build-web-amd64:
  build-web:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-web == 'true'
    runs-on:
      - runs-on
      - runner=4cpu-linux-x64
      - run-id=${{ github.run_id }}-web-amd64
      - run-id=${{ github.run_id }}-web-build
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
      DEPLOYMENT: standalone
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  build-web-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-web == 'true'
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - run-id=${{ github.run_id }}-web-arm64
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  merge-web:
    needs:
      - determine-builds
      - build-web-amd64
      - build-web-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-web
      - extras=ecr-cache
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
@@ -233,171 +109,50 @@ jobs:
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES
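This merge job is the old half of the pattern this diff removes: each architecture job pushes by digest only, and docker buildx imagetools create stitches the two digests into a single multi-arch manifest list under the final tags. The replacement jobs instead build linux/amd64,linux/arm64 in one docker/build-push-action invocation. A minimal side-by-side sketch (image name and digest variables are illustrative):

    # Old shape: per-arch push-by-digest, then a manifest merge step
    - run: |
        # each build job exported its image digest; join them under one tag
        docker buildx imagetools create \
          -t example/app:v1 \
          example/app@${AMD64_DIGEST} \
          example/app@${ARM64_DIGEST}

    # New shape: one job, one builder, both platforms in a single build
    - uses: docker/build-push-action@v6
      with:
        platforms: linux/amd64,linux/arm64
        push: true
        tags: example/app:v1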
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

  build-web-cloud-amd64:
      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache,mode=max
  build-web-cloud:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-web-cloud == 'true'
    runs-on:
      - runs-on
      - runner=4cpu-linux-x64
      - run-id=${{ github.run_id }}-web-cloud-amd64
      - run-id=${{ github.run_id }}-web-cloud-build
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
      DEPLOYMENT: cloud
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NEXT_PUBLIC_CLOUD_ENABLED=true
            NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  build-web-cloud-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-web-cloud == 'true'
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - run-id=${{ github.run_id }}-web-cloud-arm64
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NEXT_PUBLIC_CLOUD_ENABLED=true
            NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  merge-web-cloud:
    needs:
      - determine-builds
      - build-web-cloud-amd64
      - build-web-cloud-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-web-cloud
      - extras=ecr-cache
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
@@ -409,153 +164,58 @@ jobs:
          tags: |
            type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

  build-backend-amd64:
      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NEXT_PUBLIC_CLOUD_ENABLED=true
            NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache,mode=max
  build-backend:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-backend-amd64
      - run-id=${{ github.run_id }}-backend-build
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
      DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  build-backend-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-backend-arm64
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
  merge-backend:
    needs:
      - determine-builds
      - build-backend-amd64
      - build-backend-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-backend
      - extras=ecr-cache
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
@@ -570,51 +230,8 @@ jobs:
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES
  build-model-server-amd64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-model-server == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-model-server-amd64
      - volume=40gb
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
        with:
          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -622,115 +239,43 @@ jobs:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
      - name: Build and push
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        env:
          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64
          file: ./backend/Dockerfile
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
          provenance: false
          sbom: false
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache,mode=max
  build-model-server-arm64:
  build-model-server:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-model-server == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-model-server-arm64
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-model-server-build
      - volume=40gb
      - extras=ecr-cache
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
      DOCKER_BUILDKIT: 1
      BUILDKIT_PROGRESS: plain
      DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
        with:
          images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
        with:
          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        env:
          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
          outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
          provenance: false
          sbom: false
  merge-model-server:
    needs:
      - determine-builds
      - build-model-server-amd64
      - build-model-server-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-model-server
      - extras=ecr-cache
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
@@ -745,26 +290,43 @@ jobs:
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
          type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
        with:
          driver-opts: |
            image=moby/buildkit:latest
            network=host

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and push
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache,mode=max
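Each build step in this file layers two BuildKit cache backends: type=inline embeds cache metadata in the pushed image itself (which is why :latest can seed later builds via cache-from), while type=registry with mode=max pushes a dedicated cache ref that also covers intermediate stages. A minimal standalone sketch of the same layering (image and cache refs are illustrative):

    - uses: docker/build-push-action@v6
      with:
        context: .
        push: true
        tags: example/app:v1
        cache-from: |
          type=registry,ref=example/app:latest         # inline cache embedded in a prior image
          type=registry,ref=example/cache:app          # dedicated cache ref
        cache-to: |
          type=inline                                  # embed cache metadata in example/app:v1
          type=registry,ref=example/cache:app,mode=max # export all layers, incl. intermediate stages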
  trivy-scan-web:
    needs:
      - determine-builds
      - merge-web
    if: needs.merge-web.result == 'success'
    needs: [determine-builds, build-web]
    if: needs.build-web.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-trivy-scan-web
      - extras=ecr-cache
    env:
@@ -797,13 +359,11 @@ jobs:
          ${SCAN_IMAGE}

  trivy-scan-web-cloud:
    needs:
      - determine-builds
      - merge-web-cloud
    if: needs.merge-web-cloud.result == 'success'
    needs: [determine-builds, build-web-cloud]
    if: needs.build-web-cloud.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-trivy-scan-web-cloud
      - extras=ecr-cache
    env:
@@ -836,13 +396,11 @@ jobs:
          ${SCAN_IMAGE}

  trivy-scan-backend:
    needs:
      - determine-builds
      - merge-backend
    if: needs.merge-backend.result == 'success'
    needs: [determine-builds, build-backend]
    if: needs.build-backend.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-trivy-scan-backend
      - extras=ecr-cache
    env:
@@ -852,8 +410,6 @@ jobs:

      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
@@ -882,13 +438,11 @@ jobs:
          ${SCAN_IMAGE}

  trivy-scan-model-server:
    needs:
      - determine-builds
      - merge-model-server
    if: needs.merge-model-server.result == 'success'
    needs: [determine-builds, build-model-server]
    if: needs.build-model-server.result == 'success'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-trivy-scan-model-server
      - extras=ecr-cache
    env:
@@ -921,85 +475,33 @@ jobs:
          ${SCAN_IMAGE}
  notify-slack-on-failure:
    needs:
      - build-web-amd64
      - build-web-arm64
      - merge-web
      - build-web-cloud-amd64
      - build-web-cloud-arm64
      - merge-web-cloud
      - build-backend-amd64
      - build-backend-arm64
      - merge-backend
      - build-model-server-amd64
      - build-model-server-arm64
      - merge-model-server
    if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
    # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
    needs: [build-web, build-web-cloud, build-backend, build-model-server]
    if: always() && (needs.build-web.result == 'failure' || needs.build-web-cloud.result == 'failure' || needs.build-backend.result == 'failure' || needs.build-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
    runs-on: ubuntu-slim
    steps:
      - name: Checkout
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Determine failed jobs
        id: failed-jobs
        shell: bash
        run: |
          FAILED_JOBS=""
          if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
          if [ "${{ needs.build-web.result }}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-arm64\\n"
          if [ "${{ needs.build-web-cloud.result }}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-cloud\\n"
          fi
          if [ "${NEEDS_MERGE_WEB_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-web\\n"
          if [ "${{ needs.build-backend.result }}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_WEB_CLOUD_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-web-cloud\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_BACKEND_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-backend\\n"
          fi
          if [ "${NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-model-server-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-model-server-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_MODEL_SERVER_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-model-server\\n"
          if [ "${{ needs.build-model-server.result }}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-model-server\\n"
          fi
          # Remove trailing \n and set output
          FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
          echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
        env:
          NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
          NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
          NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
          NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT: ${{ needs.build-web-cloud-amd64.result }}
          NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT: ${{ needs.build-web-cloud-arm64.result }}
          NEEDS_MERGE_WEB_CLOUD_RESULT: ${{ needs.merge-web-cloud.result }}
          NEEDS_BUILD_BACKEND_AMD64_RESULT: ${{ needs.build-backend-amd64.result }}
          NEEDS_BUILD_BACKEND_ARM64_RESULT: ${{ needs.build-backend-arm64.result }}
          NEEDS_MERGE_BACKEND_RESULT: ${{ needs.merge-backend.result }}
          NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT: ${{ needs.build-model-server-amd64.result }}
          NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT: ${{ needs.build-model-server-arm64.result }}
          NEEDS_MERGE_MODEL_SERVER_RESULT: ${{ needs.merge-model-server.result }}

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
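With only four upstream jobs left, the per-job if/fi chain stays manageable, but the same aggregation can be written once over the whole needs context. A hedged sketch of that variant (not how this workflow does it; jq is preinstalled on GitHub-hosted Ubuntu runners):

    - name: Determine failed jobs
      id: failed-jobs
      env:
        NEEDS_JSON: ${{ toJSON(needs) }}   # map of job name -> { result, outputs, ... }
      run: |
        # Pull out every job whose result was failure, as a Slack-style bullet list
        FAILED_JOBS=$(echo "$NEEDS_JSON" | jq -r 'to_entries[] | select(.value.result == "failure") | "• \(.key)"')
        # Multiline step output via the heredoc form of GITHUB_OUTPUT
        echo "jobs<<EOF" >> "$GITHUB_OUTPUT"
        echo "$FAILED_JOBS" >> "$GITHUB_OUTPUT"
        echo "EOF" >> "$GITHUB_OUTPUT"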
.github/workflows/docker-tag-beta.yml (vendored): 15 changes

@@ -10,9 +10,6 @@ on:
        description: "The version (ie v1.0.0-beta.0) to tag as beta"
        required: true

permissions:
  contents: read

jobs:
  tag:
    # See https://runs-on.com/runners/linux/
@@ -32,19 +29,13 @@ jobs:
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push API Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push Model Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
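Despite the step names, these retags never pull anything: docker buildx imagetools create -t NEW EXISTING copies a (multi-arch) manifest to a new tag entirely on the registry side, so no layers move. The only change in this hunk is the env-var indirection for the user-supplied version. A usage sketch (the version value is illustrative; in this workflow it comes from the dispatch input):

    - name: Promote an existing multi-arch image to a moving tag
      env:
        VERSION: v1.2.3   # illustrative
      run: |
        # Registry-side manifest copy: no image layers are downloaded or uploaded
        docker buildx imagetools create \
          -t onyxdotapp/onyx-web-server:beta \
          onyxdotapp/onyx-web-server:${VERSION}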
.github/workflows/docker-tag-latest.yml (vendored): 15 changes

@@ -10,9 +10,6 @@ on:
        description: "The version (ie v0.0.1) to tag as latest"
        required: true

permissions:
  contents: read

jobs:
  tag:
    # See https://runs-on.com/runners/linux/
@@ -32,19 +29,13 @@ jobs:
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push API Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}

      - name: Pull, Tag and Push Model Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
.github/workflows/helm-chart-releases.yml (vendored): 1 change

@@ -17,7 +17,6 @@ jobs:
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Install Helm CLI
        uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
.github/workflows/nightly-scan-licenses.yml (vendored): 11 changes

@@ -15,21 +15,16 @@ on:
permissions:
  actions: read
  contents: read
  security-events: write

jobs:
  scan-licenses:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
    permissions:
      actions: read
      contents: read
      security-events: write

    steps:
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
@@ -59,9 +54,7 @@ jobs:

      - name: Print report
        if: always()
        env:
          REPORT: ${{ steps.license_check_report.outputs.report }}
        run: echo "$REPORT"
        run: echo "${{ steps.license_check_report.outputs.report }}"

      - name: Install npm dependencies
        working-directory: ./web
@@ -8,9 +8,6 @@ on:
  pull_request:
    branches: [main]

permissions:
  contents: read

env:
  # AWS
  S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
@@ -40,8 +37,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Discover test directories
        id: set-matrix
@@ -72,8 +67,6 @@ jobs:

      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
@@ -104,12 +97,10 @@ jobs:

      - name: Run Tests for ${{ matrix.test-dir }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        env:
          TEST_DIR: ${{ matrix.test-dir }}
        run: |
          py.test \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/external_dependency_unit/${TEST_DIR}
            backend/tests/external_dependency_unit/${{ matrix.test-dir }}
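The TEST_DIR pair above is the same env-indirection hardening applied to a matrix value: the directory name reaches the shell as an environment variable rather than being spliced into the script text. A minimal sketch of the pattern (matrix values are illustrative, not this repo's actual directories):

    jobs:
      tests:
        runs-on: ubuntu-latest
        strategy:
          matrix:
            test-dir: [connectors, llm, indexing]   # illustrative
        steps:
          - name: Run one matrix slice
            env:
              TEST_DIR: ${{ matrix.test-dir }}   # expanded by the runner, passed as data
            run: py.test -xv --ff backend/tests/external_dependency_unit/"${TEST_DIR}"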
.github/workflows/pr-helm-chart-testing.yml (vendored): 10 changes

@@ -9,9 +9,6 @@ on:
    branches: [ main ]
  workflow_dispatch: # Allows manual triggering

permissions:
  contents: read

jobs:
  helm-chart-check:
    # See https://runs-on.com/runners/linux/
@@ -23,7 +20,6 @@ jobs:
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Set up Helm
        uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
@@ -36,11 +32,9 @@ jobs:
      # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
      - name: Run chart-testing (list-changed)
        id: list-changed
        env:
          DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
        run: |
          echo "default_branch: ${DEFAULT_BRANCH}"
          changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
          echo "default_branch: ${{ github.event.repository.default_branch }}"
          changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
          echo "list-changed output: $changed"
          if [[ -n "$changed" ]]; then
            echo "changed=true" >> "$GITHUB_OUTPUT"
.github/workflows/pr-integration-tests.yml (vendored): 71 changes

@@ -10,9 +10,6 @@ on:
      - main
      - "release/**"

permissions:
  contents: read

env:
  # Test Environment Variables
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -40,8 +37,6 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Discover test directories
        id: set-matrix
@@ -70,8 +65,6 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -92,11 +85,8 @@ jobs:
          file: ./backend/Dockerfile
          push: true
          tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
            type=registry,ref=onyxdotapp/onyx-backend:latest
          cache-to: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

@@ -106,8 +96,6 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -128,10 +116,8 @@ jobs:
          file: ./backend/Dockerfile.model_server
          push: true
          tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
            type=registry,ref=onyxdotapp/onyx-model-server:latest
          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
          cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
          cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max


  build-integration-image:
@@ -140,8 +126,6 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -157,16 +141,9 @@ jobs:

      - name: Build and push integration test image with Docker Bake
        env:
          INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
          REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
          TAG: integration-test-${{ github.run_id }}
        run: |
          cd backend && docker buildx bake --push \
            --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
            --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
            --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
            --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
            --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
            integration
        run: cd backend && docker buildx bake --push integration
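docker buildx bake drives named build targets from a bake file in backend/ (its contents are not shown in this diff), and --set lets the workflow override one attribute of one target per flag without editing that file; the removed lines above show repeated --set flags layering several cache sources onto the same target. A hedged command sketch (target and cache names are illustrative):

    - name: Bake with ad-hoc cache overrides
      run: |
        # --set <target>.<attr>=<value> overrides a single bake-target attribute
        docker buildx bake --push \
          --set integration.cache-from=type=registry,ref=example/cache:integration \
          --set integration.cache-to=type=registry,ref=example/cache:integration,mode=max \
          integration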
  integration-tests:
    needs:
@@ -191,8 +168,6 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      # needed for pulling Vespa, Redis, Postgres, and Minio images
      # otherwise, we hit the "Unauthenticated users" limit
@@ -206,9 +181,6 @@ jobs:
      # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
      # NOTE: don't need web server for integration tests
      - name: Start Docker containers
        env:
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -217,8 +189,8 @@ jobs:
          POSTGRES_USE_NULL_POOL=true \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
          INTEGRATION_TESTS_MODE=true \
          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
          docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
@@ -350,8 +322,6 @@ jobs:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
        with:
          persist-credentials: false

      - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -360,9 +330,6 @@ jobs:
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Start Docker containers for multi-tenant tests
        env:
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -370,8 +337,8 @@ jobs:
          AUTH_TYPE=cloud \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
          ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
          ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
          DEV_MODE=true \
          docker compose -f docker-compose.multitenant-dev.yml up \
            relational_db \
@@ -412,9 +379,6 @@ jobs:
          echo "Finished waiting for service."

      - name: Run Multi-Tenant Integration Tests
        env:
          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
          RUN_ID: ${{ github.run_id }}
        run: |
          echo "Running multi-tenant integration tests..."
          docker run --rm --network onyx_default \
@@ -438,7 +402,7 @@ jobs:
            -e REQUIRE_EMAIL_VERIFICATION=false \
            -e DISABLE_TELEMETRY=true \
            -e DEV_MODE=true \
            ${ECR_CACHE}:integration-test-${RUN_ID} \
            ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
            /app/tests/integration/multitenant_tests

      - name: Dump API server logs (multi-tenant)
@@ -472,6 +436,13 @@ jobs:
    needs: [integration-tests, multitenant-tests]
    if: ${{ always() }}
    steps:
      - name: Check job status
        if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
        run: exit 1
      - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
        with:
          script: |
            const needs = ${{ toJSON(needs) }};
            const failed = Object.values(needs).some(n => n.result !== 'success');
            if (failed) {
              core.setFailed('One or more upstream jobs failed or were cancelled.');
            } else {
              core.notice('All required jobs succeeded.');
            }
5
.github/workflows/pr-jest-tests.yml
vendored
5
.github/workflows/pr-jest-tests.yml
vendored
@@ -5,9 +5,6 @@ concurrency:

on: push

permissions:
contents: read

jobs:
jest-tests:
name: Jest Tests
@@ -15,8 +12,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4

3
.github/workflows/pr-labeler.yml
vendored
3
.github/workflows/pr-labeler.yml
vendored
@@ -1,7 +1,7 @@
name: PR Labeler

on:
pull_request:
pull_request_target:
branches:
- main
types:
@@ -12,6 +12,7 @@ on:

permissions:
contents: read
pull-requests: write

jobs:
validate_pr_title:

3
.github/workflows/pr-linear-check.yml
vendored
3
.github/workflows/pr-linear-check.yml
vendored
@@ -7,9 +7,6 @@ on:
pull_request:
types: [opened, edited, reopened, synchronize]

permissions:
contents: read

jobs:
linear-check:
runs-on: ubuntu-latest

56
.github/workflows/pr-mit-integration-tests.yml
vendored
56
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -7,9 +7,6 @@ on:
merge_group:
types: [checks_requested]

permissions:
contents: read

env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -36,8 +33,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Discover test directories
id: set-matrix
@@ -65,8 +60,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -87,10 +80,8 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

build-model-server-image:
@@ -99,8 +90,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -121,10 +110,8 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max

build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
@@ -132,8 +119,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -149,16 +134,9 @@ jobs:

- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
run: cd backend && docker buildx bake --push integration

integration-tests-mit:
needs:
@@ -183,8 +161,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -198,9 +174,6 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
@@ -208,8 +181,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
@@ -334,6 +307,13 @@ jobs:
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}

62
.github/workflows/pr-playwright-tests.yml
vendored
62
.github/workflows/pr-playwright-tests.yml
vendored
@@ -5,9 +5,6 @@ concurrency:

on: push

permissions:
contents: read

env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -45,8 +42,6 @@ jobs:

- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -67,10 +62,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
type=registry,ref=onyxdotapp/onyx-web-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

build-backend-image:
@@ -80,8 +73,6 @@ jobs:

- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -102,11 +93,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

build-model-server-image:
@@ -116,8 +104,6 @@ jobs:

- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -138,10 +124,8 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
push: true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

playwright-tests:
@@ -159,7 +143,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false

- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
@@ -185,22 +168,17 @@ jobs:
run: npx playwright install --with-deps

- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
EXA_API_KEY=${{ env.EXA_API_KEY }}
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
ONYX_WEB_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
EOF

# needed for pulling Vespa, Redis, Postgres, and Minio images
@@ -277,12 +255,10 @@ jobs:

- name: Run Playwright tests
working-directory: ./web
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${PROJECT}
npx playwright test --project ${{ matrix.project }}

- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
@@ -295,12 +271,10 @@ jobs:
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log

- name: Upload logs
if: success() || failure()
@@ -309,16 +283,6 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log

playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1


# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.

5
.github/workflows/pr-python-checks.yml
vendored
5
.github/workflows/pr-python-checks.yml
vendored
@@ -10,9 +10,6 @@ on:
- main
- 'release/**'

permissions:
contents: read

jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
@@ -24,8 +21,6 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

@@ -11,9 +11,6 @@ on:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"

permissions:
contents: read

env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
@@ -135,8 +132,6 @@ jobs:

- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -219,10 +214,8 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

9
.github/workflows/pr-python-model-tests.yml
vendored
9
.github/workflows/pr-python-model-tests.yml
vendored
@@ -11,9 +11,6 @@ on:
required: false
default: 'main'

permissions:
contents: read

env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -39,8 +36,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -127,12 +122,10 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

- name: Dump all-container logs (optional)

5
.github/workflows/pr-python-tests.yml
vendored
5
.github/workflows/pr-python-tests.yml
vendored
@@ -10,9 +10,6 @@ on:
- main
- 'release/**'

permissions:
contents: read

jobs:
backend-check:
# See https://runs-on.com/runners/linux/
@@ -31,8 +28,6 @@ jobs:

- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false

- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies

4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -7,9 +7,6 @@ on:
merge_group:
pull_request: null

permissions:
contents: read

jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
@@ -19,7 +16,6 @@ jobs:
- uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: "3.11"

1
.github/workflows/sync_foss.yml
vendored
1
.github/workflows/sync_foss.yml
vendored
@@ -16,7 +16,6 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false

- name: Install git-filter-repo
run: |

22
.github/workflows/tag-nightly.yml
vendored
22
.github/workflows/tag-nightly.yml
vendored
@@ -3,29 +3,30 @@ name: Nightly Tag Push
on:
schedule:
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
workflow_dispatch:

permissions:
contents: write # Allows pushing tags to the repository

jobs:
create-and-push-tag:
runs-on: ubuntu-slim
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-create-and-push-tag"]

steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key

# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
# and not rkuo's personal account. It is fine to leave this key as is!
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"

- name: Set up Git user
run: |
git config user.name "Onyx Bot [bot]"
git config user.email "onyx-bot[bot]@onyx.app"
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@onyx.app"

- name: Check for existing nightly tag
id: check_tag
@@ -53,12 +54,3 @@ jobs:
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME

- name: Send Slack notification
if: failure()
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
title: "🚨 Nightly Tag Push Failed"
ref-name: ${{ github.ref_name }}
failed-jobs: "create-and-push-tag"

35
.github/workflows/zizmor.yml
vendored
35
.github/workflows/zizmor.yml
vendored
@@ -1,35 +0,0 @@
name: Run Zizmor

on:
push:
branches: ["main"]
pull_request:
branches: ["**"]

permissions: {}

jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
permissions:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # ratchet:actions/checkout@v5.0.1
with:
persist-credentials: false

- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3

- name: Run zizmor
run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif
category: zizmor

@@ -1,89 +0,0 @@
"""add internet search and content provider tables

Revision ID: 1f2a3b4c5d6e
Revises: 9drpiiw74ljy
Create Date: 2025-11-10 19:45:00.000000

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "1f2a3b4c5d6e"
down_revision = "9drpiiw74ljy"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"internet_search_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_search_provider_is_active",
"internet_search_provider",
["is_active"],
)

op.create_table(
"internet_content_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_content_provider_is_active",
"internet_content_provider",
["is_active"],
)


def downgrade() -> None:
op.drop_index(
"ix_internet_content_provider_is_active", table_name="internet_content_provider"
)
op.drop_table("internet_content_provider")
op.drop_index(
"ix_internet_search_provider_is_active", table_name="internet_search_provider"
)
op.drop_table("internet_search_provider")

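A hypothetical one-off use of the tables the deleted migration created, to show the intended schema in action — the DSN, the provider name "exa", and the choice of raw SQL are all placeholders, not repo code:

from sqlalchemy import create_engine, text

# Placeholder connection string; point this at the actual Postgres instance.
engine = create_engine("postgresql://user:pass@localhost:5432/onyx")

with engine.begin() as conn:
    # Flip one search provider to active; "exa" is an illustrative name.
    conn.execute(
        text(
            "UPDATE internet_search_provider "
            "SET is_active = true, time_updated = now() "
            "WHERE name = :name"
        ),
        {"name": "exa"},
    )
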
@@ -1,16 +1,4 @@
group "default" {
targets = ["backend", "model-server"]
}

variable "BACKEND_REPOSITORY" {
default = "onyxdotapp/onyx-backend"
}

variable "MODEL_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-model-server"
}

variable "INTEGRATION_REPOSITORY" {
variable "REPOSITORY" {
default = "onyxdotapp/onyx-integration"
}

@@ -21,22 +9,6 @@ variable "TAG" {
target "backend" {
context = "."
dockerfile = "Dockerfile"

cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
cache-to = ["type=inline"]

tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}

target "model-server" {
context = "."

dockerfile = "Dockerfile.model_server"

cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]

tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
}

target "integration" {
@@ -48,5 +20,8 @@ target "integration" {
base = "target:backend"
}

tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
cache-from = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache"]
cache-to = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache,mode=max"]

tags = ["${REPOSITORY}:${TAG}"]
}

@@ -116,7 +116,7 @@ def _concurrent_embedding(
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.warning(f"Error encoding texts, retrying: {e}")
logger.error(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)

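The hunk above only changes the log level inside an existing retry; a minimal standalone sketch of that retry-then-final-attempt pattern, assuming a SentenceTransformer-style `model.encode` and illustrative constants (not the repo's exact helpers):

import time

ENCODING_RETRY_DELAY = 0.1  # illustrative; the repo defines its own value
MAX_ATTEMPTS = 3

def encode_with_retry(model, texts, normalize_embeddings=True):
    # "RuntimeError: Already borrowed" from the transformers tokenizer is
    # transient under concurrency, so retry a few times, then make one final
    # uncaught attempt so a persistent failure still propagates.
    for _ in range(MAX_ATTEMPTS - 1):
        try:
            return model.encode(texts, normalize_embeddings=normalize_embeddings)
        except RuntimeError:
            time.sleep(ENCODING_RETRY_DELAY)
    return model.encode(texts, normalize_embeddings=normalize_embeddings)
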
@@ -1,73 +0,0 @@
import json
from collections.abc import Sequence
from typing import cast

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import FunctionMessage

from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import FunctionCall
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithText


HUMAN = "human"
SYSTEM = "system"
AI = "ai"
FUNCTION = "function"


def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]


def _base_message_to_chat_completion_msg(
msg: BaseMessage,
) -> ChatCompletionMessage:
if msg.type == HUMAN:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
user_msg: UserMessageWithText = {"role": "user", "content": content}
return user_msg
if msg.type == SYSTEM:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
system_msg: SystemMessage = {"role": "system", "content": content}
return system_msg
if msg.type == AI:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
if isinstance(msg, AIMessage) and msg.tool_calls:
assistant_msg["tool_calls"] = [
ToolCall(
id=tool_call.get("id") or "",
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=json.dumps(tool_call["args"]),
),
)
for tool_call in msg.tool_calls
]
return assistant_msg
if msg.type == FUNCTION:
function_message = cast(FunctionMessage, msg)
content = (
function_message.content
if isinstance(function_message.content, str)
else str(function_message.content)
)
tool_msg: ToolMessage = {
"role": "tool",
"content": content,
"tool_call_id": function_message.name or "",
}
return tool_msg
raise ValueError(f"Unexpected message type: {msg.type}")

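For context on what the deleted converter did: it flattened LangChain message objects into plain chat-completion dicts. A usage sketch (base_messages_to_chat_completion_msgs came from the removed module, so the call below is illustrative only):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

msgs = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Summarize the release notes."),
    AIMessage(content="Here is a summary..."),
]

# base_messages_to_chat_completion_msgs was defined in the removed module.
chat_msgs = base_messages_to_chat_completion_msgs(msgs)
# -> [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]
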
@@ -1,11 +1,9 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
@@ -18,10 +16,6 @@ from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError


@dataclass
@@ -39,75 +33,6 @@ def _serialize_tool_output(output: Any) -> str:
return str(output)


def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []

if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []

tool_calls: list[dict[str, Any]] = []

for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")

if not isinstance(name, str) or arguments is None:
continue

if not isinstance(arguments, dict):
continue

call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)

return tool_calls


def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return

tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)

if not tool_calls_from_content:
return

content_parts.clear()

for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}


def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
@@ -140,225 +65,150 @@ def query(
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}

new_messages_stateful: list[ChatCompletionMessage] = []

current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]

def stream_generator() -> Iterator[StreamEvent]:
message_started = False
reasoning_started = False
message_started = False

tool_calls_in_progress: dict[int, dict[str, Any]] = {}

content_parts: list[str] = []
reasoning_parts: list[str] = []

synthetic_tool_call_counter = 0
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
):
assert isinstance(chunk, ModelResponseStream)

def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason

with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}

delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason

if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True

if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True

if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False

for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)

yield chunk

if not finish_reason:
continue
if delta.reasoning_content:
reasoning_parts.append(delta.reasoning_content)
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True

if delta.content:
content_parts.append(delta.content)
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True

if delta.tool_calls:
if reasoning_started and not message_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False

if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)

if content_parts:
new_messages_stateful.append(
yield chunk

if not finish_reason:
continue
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False

if finish_reason == "tool_calls" and tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())

# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}

for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]

if call_id is None or name is None:
continue

assistant_tool_calls.append(
{
"role": "assistant",
"content": "".join(content_parts),
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
span_generation.span_data.output = new_messages_stateful

# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)

# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)

for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
run_context = RunContextWrapper(context=context)

if call_id is None or name is None:
continue
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)

assistant_tool_calls.append(
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)

new_messages_stateful.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)

yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]

if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)

elif finish_reason == "stop" and content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)

if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)

run_context = RunContextWrapper(context=context)

# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output

yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)

new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)

for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]

if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)

return QueryResult(
stream=stream_generator(),
new_messages_stateful=new_messages_stateful,

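A consumer-side sketch (not from the repo) of how the rewritten stream is meant to be drained — the tool-execution branches above only run while the generator is iterated, and token deltas ride on the raw ModelResponseStream chunks:

def collect_assistant_text(result) -> str:
    # `result` is the QueryResult returned by query(); iterate its stream
    # fully so tool calls execute and new_messages_stateful is populated.
    # ModelResponseStream is the class imported at the top of this module.
    parts: list[str] = []
    for event in result.stream:
        if isinstance(event, ModelResponseStream):
            delta = event.choice.delta
            if delta.content:
                parts.append(delta.content)
        # RunItemStreamEvent values mark message/reasoning/tool boundaries
        # and can drive UI updates here instead of being ignored.
    return "".join(parts)
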
@@ -26,9 +26,9 @@ def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> Non
# Without this patch, the library uses special formatting that breaks our custom tools
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
if tool_choice == "web_search":
return "web_search"
return {"type": "function", "name": "web_search"}
if tool_choice == "image_generation":
return "image_generation"
return {"type": "function", "name": "image_generation"}
return orig_func(cls, tool_choice)

OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]

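The general shape of the patch above, reduced to a self-contained sketch — the class and method here are stand-ins, not the litellm types:

class Converter:
    @classmethod
    def convert_tool_choice(cls, tool_choice: str) -> object:
        return tool_choice

# Capture the original underlying function before replacing the classmethod.
orig_func = Converter.convert_tool_choice.__func__

def _patched(cls: type, tool_choice: str) -> object:
    # Route the hosted-tool name through the plain function-call format.
    if tool_choice == "web_search":
        return {"type": "function", "name": "web_search"}
    return orig_func(cls, tool_choice)

Converter.convert_tool_choice = classmethod(_patched)  # type: ignore[method-assign]
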
@@ -12,12 +12,13 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder


class ExaClient(WebSearchProvider):
def __init__(self, api_key: str) -> None:
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)

@retry_builder(tries=3, delay=1, backoff=2)

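retry_builder(tries=3, delay=1, backoff=2) gives three total attempts with waits of 1s then 2s between them; an illustrative equivalent of that decorator (the repo's actual implementation may differ):

import time
from functools import wraps

def retry_builder(tries: int = 3, delay: float = 1, backoff: float = 2):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(tries):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    # Re-raise once the attempt budget is spent.
                    if attempt == tries - 1:
                        raise
                    time.sleep(wait)
                    wait *= backoff
        return wrapper
    return decorator
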
@@ -1,163 +0,0 @@
from __future__ import annotations

from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any

import requests

from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

logger = setup_logger()

FIRECRAWL_SCRAPE_URL = "https://api.firecrawl.dev/v1/scrape"
_DEFAULT_MAX_WORKERS = 4


@dataclass
class ExtractedContentFields:
text: str
title: str
published_date: datetime | None


class FirecrawlClient(WebContentProvider):
def __init__(
self,
api_key: str,
*,
base_url: str = FIRECRAWL_SCRAPE_URL,
timeout_seconds: int = 30,
) -> None:

self._headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
self._base_url = base_url
self._timeout_seconds = timeout_seconds
self._last_error: str | None = None

@property
def last_error(self) -> str | None:
return self._last_error

def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []

max_workers = min(_DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._get_webpage_content_safe, urls))

def _get_webpage_content_safe(self, url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception as exc:
self._last_error = str(exc)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)

@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
"formats": ["markdown"],
}

response = requests.post(
self._base_url,
headers=self._headers,
json=payload,
timeout=self._timeout_seconds,
)

if response.status_code != 200:
try:
error_payload = response.json()
except Exception:
error_payload = response.text
self._last_error = (
error_payload if isinstance(error_payload, str) else str(error_payload)
)

if 400 <= response.status_code < 500:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)

raise ValueError(
f"Firecrawl fetch failed with status {response.status_code}."
)
else:
self._last_error = None

response_json = response.json()
extracted = self._extract_content_fields(response_json, url)

return WebContent(
title=extracted.title,
link=url,
full_content=extracted.text,
published_date=extracted.published_date,
scrape_successful=bool(extracted.text),
)

@staticmethod
def _extract_content_fields(
response_json: dict[str, Any], url: str
) -> ExtractedContentFields:
data_section = response_json.get("data") or {}
metadata = data_section.get("metadata") or response_json.get("metadata") or {}

text_candidates = [
data_section.get("markdown"),
data_section.get("content"),
data_section.get("text"),
response_json.get("markdown"),
response_json.get("content"),
response_json.get("text"),
]

text = next((candidate for candidate in text_candidates if candidate), "")
title = metadata.get("title") or response_json.get("title") or ""
published_date = None

published_date_str = (
metadata.get("publishedTime")
or metadata.get("date")
or response_json.get("publishedTime")
or response_json.get("date")
)

if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None

if not text:
logger.warning(f"Firecrawl returned empty content for url={url}")

return ExtractedContentFields(
text=text or "",
title=title or "",
published_date=published_date,
)

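The contents() method above uses a bounded fan-out; the same pattern in isolation, where fetch_one is any per-URL callable:

from concurrent.futures import ThreadPoolExecutor

def fetch_all(fetch_one, urls, max_workers_cap: int = 4):
    # Never spawn more workers than there are URLs; executor.map preserves
    # input order, so results line up with the urls argument.
    if not urls:
        return []
    max_workers = min(max_workers_cap, len(urls))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(fetch_one, urls))
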
@@ -1,138 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Sequence
-from datetime import datetime
-from typing import Any
-
-import requests
-
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
-    WebSearchProvider,
-)
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
-from onyx.utils.logger import setup_logger
-from onyx.utils.retry_wrapper import retry_builder
-
-logger = setup_logger()
-
-GOOGLE_CUSTOM_SEARCH_URL = "https://customsearch.googleapis.com/customsearch/v1"
-
-
-class GooglePSEClient(WebSearchProvider):
-    def __init__(
-        self,
-        api_key: str,
-        search_engine_id: str,
-        *,
-        num_results: int = 10,
-        timeout_seconds: int = 10,
-    ) -> None:
-        self._api_key = api_key
-        self._search_engine_id = search_engine_id
-        self._num_results = num_results
-        self._timeout_seconds = timeout_seconds
-
-    @retry_builder(tries=3, delay=1, backoff=2)
-    def search(self, query: str) -> list[WebSearchResult]:
-        params: dict[str, str] = {
-            "key": self._api_key,
-            "cx": self._search_engine_id,
-            "q": query,
-            "num": str(self._num_results),
-        }
-
-        response = requests.get(
-            GOOGLE_CUSTOM_SEARCH_URL, params=params, timeout=self._timeout_seconds
-        )
-
-        # Check for HTTP errors first
-        try:
-            response.raise_for_status()
-        except requests.HTTPError as exc:
-            status = response.status_code
-            error_detail = "Unknown error"
-            try:
-                error_data = response.json()
-                if "error" in error_data:
-                    error_info = error_data["error"]
-                    error_detail = error_info.get("message", str(error_info))
-            except Exception:
-                error_detail = (
-                    response.text[:200] if response.text else "No error details"
-                )
-
-            raise ValueError(
-                f"Google PSE search failed (status {status}): {error_detail}"
-            ) from exc
-
-        data = response.json()
-
-        # Google Custom Search API can return errors in the response body even with 200 status
-        if "error" in data:
-            error_info = data["error"]
-            error_message = error_info.get("message", "Unknown error")
-            error_code = error_info.get("code", "Unknown")
-            raise ValueError(f"Google PSE API error ({error_code}): {error_message}")
-
-        items: list[dict[str, Any]] = data.get("items", [])
-        results: list[WebSearchResult] = []
-
-        for item in items:
-            link = item.get("link")
-            if not link:
-                continue
-
-            snippet = item.get("snippet") or ""
-
-            # Attempt to extract metadata if available
-            pagemap = item.get("pagemap") or {}
-            metatags = pagemap.get("metatags", [])
-            published_date: datetime | None = None
-            author: str | None = None
-
-            if metatags:
-                meta = metatags[0]
-                author = meta.get("og:site_name") or meta.get("author")
-                published_str = (
-                    meta.get("article:published_time")
-                    or meta.get("og:updated_time")
-                    or meta.get("date")
-                )
-                if published_str:
-                    try:
-                        published_date = datetime.fromisoformat(
-                            published_str.replace("Z", "+00:00")
-                        )
-                    except ValueError:
-                        logger.debug(
-                            f"Failed to parse published_date '{published_str}' for link {link}"
-                        )
-                        published_date = None
-
-            results.append(
-                WebSearchResult(
-                    title=item.get("title") or "",
-                    link=link,
-                    snippet=snippet,
-                    author=author,
-                    published_date=published_date,
-                )
-            )
-
-        return results
-
-    def contents(self, urls: Sequence[str]) -> list[WebContent]:
-        logger.warning(
-            "Google PSE does not support content fetching; returning empty results."
-        )
-        return [
-            WebContent(
-                title="",
-                link=url,
-                full_content="",
-                published_date=None,
-                scrape_successful=False,
-            )
-            for url in urls
-        ]
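Note: this deletion removes the only Google Programmable Search Engine integration. For reference, a minimal standalone sketch of the raw API call the client wrapped — the endpoint and the key/cx/q/num parameters come from the code above; retry and error detail handling are trimmed:

import requests

GOOGLE_CUSTOM_SEARCH_URL = "https://customsearch.googleapis.com/customsearch/v1"


def google_pse_links(api_key: str, search_engine_id: str, query: str) -> list[str]:
    # Same request shape as GooglePSEClient.search above; returns bare links only.
    response = requests.get(
        GOOGLE_CUSTOM_SEARCH_URL,
        params={"key": api_key, "cx": search_engine_id, "q": query, "num": "10"},
        timeout=10,
    )
    response.raise_for_status()
    return [item["link"] for item in response.json().get("items", [])]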
@@ -1,94 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Sequence
-
-import requests
-
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
-    WebContent,
-)
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
-    WebContentProvider,
-)
-from onyx.file_processing.html_utils import ParsedHTML
-from onyx.file_processing.html_utils import web_html_cleanup
-from onyx.utils.logger import setup_logger
-
-logger = setup_logger()
-
-DEFAULT_TIMEOUT_SECONDS = 15
-DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
-
-
-class OnyxWebCrawlerClient(WebContentProvider):
-    """
-    Lightweight built-in crawler that fetches HTML directly and extracts readable text.
-    Acts as the default content provider when no external crawler (e.g. Firecrawl) is
-    configured.
-    """
-
-    def __init__(
-        self,
-        *,
-        timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
-        user_agent: str = DEFAULT_USER_AGENT,
-    ) -> None:
-        self._timeout_seconds = timeout_seconds
-        self._headers = {
-            "User-Agent": user_agent,
-            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
-        }
-
-    def contents(self, urls: Sequence[str]) -> list[WebContent]:
-        results: list[WebContent] = []
-        for url in urls:
-            results.append(self._fetch_url(url))
-        return results
-
-    def _fetch_url(self, url: str) -> WebContent:
-        try:
-            response = requests.get(
-                url, headers=self._headers, timeout=self._timeout_seconds
-            )
-        except Exception as exc:  # pragma: no cover - network failures vary
-            logger.warning(
-                "Onyx crawler failed to fetch %s (%s)",
-                url,
-                exc.__class__.__name__,
-            )
-            return WebContent(
-                title="",
-                link=url,
-                full_content="",
-                published_date=None,
-                scrape_successful=False,
-            )
-
-        if response.status_code >= 400:
-            logger.warning("Onyx crawler received %s for %s", response.status_code, url)
-            return WebContent(
-                title="",
-                link=url,
-                full_content="",
-                published_date=None,
-                scrape_successful=False,
-            )
-
-        try:
-            parsed: ParsedHTML = web_html_cleanup(response.text)
-            text_content = parsed.cleaned_text or ""
-            title = parsed.title or ""
-        except Exception as exc:
-            logger.warning(
-                "Onyx crawler failed to parse %s (%s)", url, exc.__class__.__name__
-            )
-            text_content = ""
-            title = ""
-
-        return WebContent(
-            title=title,
-            link=url,
-            full_content=text_content,
-            published_date=None,
-            scrape_successful=bool(text_content.strip()),
-        )
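Note: this removes the zero-dependency fallback crawler. An illustrative driver for the class defined in the file above (not part of the diff itself):

crawler = OnyxWebCrawlerClient(timeout_seconds=5)
pages = crawler.contents(["https://example.com", "https://example.org/missing"])
for page in pages:
    # scrape_successful is False for network errors, HTTP >= 400, or empty cleaned text
    print(page.link, page.scrape_successful, len(page.full_content))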
@@ -13,6 +13,7 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
 from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
     WebSearchResult,
 )
+from onyx.configs.chat_configs import SERPER_API_KEY
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from onyx.utils.retry_wrapper import retry_builder

@@ -21,7 +22,7 @@ SERPER_CONTENTS_URL = "https://scrape.serper.dev"


 class SerperClient(WebSearchProvider):
-    def __init__(self, api_key: str) -> None:
+    def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
         self.headers = {
             "X-API-KEY": api_key,
             "Content-Type": "application/json",
@@ -39,13 +40,7 @@ class SerperClient(WebSearchProvider):
             data=json.dumps(payload),
         )

-        try:
-            response.raise_for_status()
-        except Exception:
-            # Avoid leaking API keys/URLs
-            raise ValueError(
-                "Serper search failed. Check credentials or quota."
-            ) from None
+        response.raise_for_status()

         results = response.json()
         organic_results = results["organic"]
@@ -104,13 +99,7 @@ class SerperClient(WebSearchProvider):
                 scrape_successful=False,
             )

-        try:
-            response.raise_for_status()
-        except Exception:
-            # Avoid leaking API keys/URLs
-            raise ValueError(
-                "Serper content fetch failed. Check credentials."
-            ) from None
+        response.raise_for_status()

         response_json = response.json()
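Note: the two hunks above drop the error-sanitizing wrapper in favor of a bare raise_for_status(); per the removed comments, the guard existed to keep credentials and request URLs out of raised errors. The removed pattern as a standalone sketch (the helper name is hypothetical):

import requests


def raise_sanitized(response: requests.Response, message: str) -> None:
    try:
        response.raise_for_status()
    except Exception:
        # "from None" suppresses the original HTTPError, whose message embeds the URL
        raise ValueError(message) from None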
@@ -74,22 +74,13 @@ def web_search(
     if not provider:
         raise ValueError("No internet search provider found")

-    # Log which provider type is being used
-    provider_type = type(provider).__name__
-    logger.info(
-        f"Performing web search with {provider_type} for query: '{search_query}'"
-    )
-
     @traceable(name="Search Provider API Call")
     def _search(search_query: str) -> list[WebSearchResult]:
         search_results: list[WebSearchResult] = []
         try:
             search_results = list(provider.search(search_query))
-            logger.info(
-                f"Search returned {len(search_results)} results using {provider_type}"
-            )
         except Exception as e:
-            logger.error(f"Error performing search with {provider_type}: {e}")
+            logger.error(f"Error performing search: {e}")
         return search_results

     search_results: list[WebSearchResult] = _search(search_query)
@@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig
 from langgraph.types import StreamWriter

 from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
-    get_default_content_provider,
+    get_default_provider,
 )
 from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
 from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
@@ -41,9 +41,9 @@ def web_fetch(
     if graph_config.inputs.persona is None:
         raise ValueError("persona is not set")

-    provider = get_default_content_provider()
+    provider = get_default_provider()
     if provider is None:
-        raise ValueError("No web content provider found")
+        raise ValueError("No web search provider found")

     retrieved_docs: list[InferenceSection] = []
     try:
@@ -52,7 +52,7 @@ def web_fetch(
             for result in provider.contents(state.urls_to_open)
         ]
     except Exception as e:
-        logger.exception(e)
+        logger.error(f"Error fetching URLs: {e}")

     if not retrieved_docs:
         logger.warning("No content retrieved from URLs")
@@ -2,6 +2,7 @@ from abc import ABC
 from abc import abstractmethod
 from collections.abc import Sequence
 from datetime import datetime
+from enum import Enum

 from pydantic import BaseModel
 from pydantic import field_validator
@@ -9,6 +10,13 @@ from pydantic import field_validator
 from onyx.utils.url import normalize_url


+class ProviderType(Enum):
+    """Enum for internet search provider types"""
+
+    GOOGLE = "google"
+    EXA = "exa"
+
+
 class WebSearchResult(BaseModel):
     title: str
     link: str
@@ -35,13 +43,11 @@ class WebContent(BaseModel):
         return normalize_url(v)


-class WebContentProvider(ABC):
-    @abstractmethod
-    def contents(self, urls: Sequence[str]) -> list[WebContent]:
-        pass
-
-
-class WebSearchProvider(WebContentProvider):
+class WebSearchProvider(ABC):
     @abstractmethod
     def search(self, query: str) -> Sequence[WebSearchResult]:
         pass
+
+    @abstractmethod
+    def contents(self, urls: Sequence[str]) -> list[WebContent]:
+        pass
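Note: after this hunk a single WebSearchProvider ABC carries both abstract methods, so every concrete client must implement search() and contents(). A minimal conforming stub, for illustration only (WebSearchProvider, WebSearchResult, and WebContent are the models shown above):

from collections.abc import Sequence


class NullProvider(WebSearchProvider):
    # Satisfies both abstract methods with empty results; useful only as a shape check.
    def search(self, query: str) -> Sequence[WebSearchResult]:
        return []

    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        return []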
@@ -1,199 +1,19 @@
-from typing import Any
-
 from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
     ExaClient,
 )
-from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
-    FIRECRAWL_SCRAPE_URL,
-)
-from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
-    FirecrawlClient,
-)
-from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
-    GooglePSEClient,
-)
-from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
-    OnyxWebCrawlerClient,
-)
 from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
     SerperClient,
 )
-from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
-    WebContentProvider,
-)
 from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
     WebSearchProvider,
 )
-from onyx.db.engine.sql_engine import get_session_with_current_tenant
-from onyx.db.web_search import fetch_active_web_content_provider
-from onyx.db.web_search import fetch_active_web_search_provider
-from onyx.utils.logger import setup_logger
-from shared_configs.enums import WebContentProviderType
-from shared_configs.enums import WebSearchProviderType
-
-logger = setup_logger()
-
-
-def build_search_provider_from_config(
-    *,
-    provider_type: WebSearchProviderType,
-    api_key: str | None,
-    config: dict[str, str] | None,
-    provider_name: str = "web_search_provider",
-) -> WebSearchProvider | None:
-    provider_type_value = provider_type.value
-    try:
-        provider_type_enum = WebSearchProviderType(provider_type_value)
-    except ValueError:
-        logger.error(
-            f"Unknown web search provider type '{provider_type_value}'. "
-            "Skipping provider initialization."
-        )
-        return None
-
-    # All web search providers require an API key
-    if not api_key:
-        raise ValueError(
-            f"Web search provider '{provider_name}' is missing an API key."
-        )
-    assert api_key is not None
-
-    config = config or {}
-
-    if provider_type_enum == WebSearchProviderType.EXA:
-        return ExaClient(api_key=api_key)
-    if provider_type_enum == WebSearchProviderType.SERPER:
-        return SerperClient(api_key=api_key)
-    if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
-        search_engine_id = (
-            config.get("search_engine_id")
-            or config.get("cx")
-            or config.get("search_engine")
-        )
-        if not search_engine_id:
-            raise ValueError(
-                "Google PSE provider requires a search engine id (cx) in addition to the API key."
-            )
-        assert search_engine_id is not None
-        try:
-            num_results = int(config.get("num_results", 10))
-        except (TypeError, ValueError):
-            raise ValueError(
-                "Invalid value for Google PSE 'num_results'; expected integer."
-            )
-        try:
-            timeout_seconds = int(config.get("timeout_seconds", 10))
-        except (TypeError, ValueError):
-            raise ValueError(
-                "Invalid value for Google PSE 'timeout_seconds'; expected integer."
-            )
-        return GooglePSEClient(
-            api_key=api_key,
-            search_engine_id=search_engine_id,
-            num_results=num_results,
-            timeout_seconds=timeout_seconds,
-        )
-
-    logger.error(
-        f"Unhandled web search provider type '{provider_type_value}'. "
-        "Skipping provider initialization."
-    )
-    return None
-
-
-def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
-    return build_search_provider_from_config(
-        provider_type=WebSearchProviderType(provider_model.provider_type),
-        api_key=provider_model.api_key,
-        config=provider_model.config or {},
-        provider_name=provider_model.name,
-    )
-
-
-def build_content_provider_from_config(
-    *,
-    provider_type: WebContentProviderType,
-    api_key: str | None,
-    config: dict[str, str] | None,
-    provider_name: str = "web_content_provider",
-) -> WebContentProvider | None:
-    provider_type_value = provider_type.value
-    try:
-        provider_type_enum = WebContentProviderType(provider_type_value)
-    except ValueError:
-        logger.error(
-            f"Unknown web content provider type '{provider_type_value}'. "
-            "Skipping provider initialization."
-        )
-        return None
-
-    if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
-        config = config or {}
-        timeout_value = config.get("timeout_seconds", 15)
-        try:
-            timeout_seconds = int(timeout_value)
-        except (TypeError, ValueError):
-            raise ValueError(
-                "Invalid value for Onyx Web Crawler 'timeout_seconds'; expected integer."
-            )
-        return OnyxWebCrawlerClient(timeout_seconds=timeout_seconds)
-
-    if provider_type_enum == WebContentProviderType.FIRECRAWL:
-        if not api_key:
-            raise ValueError("Firecrawl content provider requires an API key.")
-        assert api_key is not None
-        config = config or {}
-        timeout_seconds_str = config.get("timeout_seconds")
-        if timeout_seconds_str is None:
-            timeout_seconds = 10
-        else:
-            try:
-                timeout_seconds = int(timeout_seconds_str)
-            except (TypeError, ValueError):
-                raise ValueError(
-                    "Invalid value for Firecrawl 'timeout_seconds'; expected integer."
-                )
-        return FirecrawlClient(
-            api_key=api_key,
-            base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
-            timeout_seconds=timeout_seconds,
-        )
-
-    logger.error(
-        f"Unhandled web content provider type '{provider_type_value}'. "
-        "Skipping provider initialization."
-    )
-    return None
-
-
-def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
-    return build_content_provider_from_config(
-        provider_type=WebContentProviderType(provider_model.provider_type),
-        api_key=provider_model.api_key,
-        config=provider_model.config or {},
-        provider_name=provider_model.name,
-    )
+from onyx.configs.chat_configs import EXA_API_KEY
+from onyx.configs.chat_configs import SERPER_API_KEY


 def get_default_provider() -> WebSearchProvider | None:
-    with get_session_with_current_tenant() as db_session:
-        provider_model = fetch_active_web_search_provider(db_session)
-    if provider_model is None:
-        return None
-    return _build_search_provider(provider_model)
-
-
-def get_default_content_provider() -> WebContentProvider | None:
-    with get_session_with_current_tenant() as db_session:
-        provider_model = fetch_active_web_content_provider(db_session)
-    if provider_model:
-        provider = _build_content_provider(provider_model)
-        if provider:
-            return provider
-
-    # Fall back to built-in Onyx crawler when nothing is configured.
-    try:
-        return OnyxWebCrawlerClient()
-    except Exception as exc:  # pragma: no cover - defensive
-        logger.error(f"Failed to initialize default Onyx crawler: {exc}")
-        return None
+    if EXA_API_KEY:
+        return ExaClient()
+    if SERPER_API_KEY:
+        return SerperClient()
+    return None
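Note: provider selection now comes from environment keys rather than database rows, and Exa takes precedence over Serper when both keys are set. A condensed, self-contained sketch of that precedence (env reads mirror the chat_configs lines added later in this diff):

import os

EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None


def pick_provider_name() -> str | None:
    # Same if-ordering as get_default_provider above: Exa wins when both are set.
    if EXA_API_KEY:
        return "exa"
    if SERPER_API_KEY:
        return "serper"
    return None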
@@ -11,8 +11,11 @@ from jwt import decode as jwt_decode
 from jwt import InvalidTokenError
 from jwt import PyJWTError
 from jwt.algorithms import RSAAlgorithm
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession

 from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
+from onyx.db.models import User
 from onyx.utils.logger import setup_logger


@@ -128,7 +131,7 @@ def _resolve_public_key_from_jwks(
     return None


-async def verify_jwt_token(token: str) -> dict[str, Any] | None:
+async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
     for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
         public_key = get_public_key(token)
         if public_key is None:
@@ -139,6 +142,8 @@ async def verify_jwt_token(token: str) -> dict[str, Any] | None:
             return None

         try:
+            from sqlalchemy import func
+
             payload = jwt_decode(
                 token,
                 public_key,
@@ -158,6 +163,15 @@ async def verify_jwt_token(token: str) -> dict[str, Any] | None:
                 continue
             return None

-        return payload
+        email = payload.get("email")
+        if email:
+            result = await async_db_session.execute(
+                select(User).where(func.lower(User.email) == func.lower(email))
+            )
+            return result.scalars().first()
+        logger.warning(
+            "JWT token decoded successfully but no email claim found; skipping auth"
+        )
+        break

     return None
@@ -1063,107 +1063,6 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
 optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)


-_JWT_EMAIL_CLAIM_KEYS = ("email", "preferred_username", "upn")
-
-
-def _extract_email_from_jwt(payload: dict[str, Any]) -> str | None:
-    """Return the best-effort email/username from a decoded JWT payload."""
-    for key in _JWT_EMAIL_CLAIM_KEYS:
-        value = payload.get(key)
-        if isinstance(value, str) and value:
-            try:
-                email_info = validate_email(value, check_deliverability=False)
-            except EmailNotValidError:
-                continue
-            normalized_email = email_info.normalized or email_info.email
-            return normalized_email.lower()
-    return None
-
-
-async def _sync_jwt_oidc_expiry(
-    user_manager: UserManager, user: User, payload: dict[str, Any]
-) -> None:
-    if TRACK_EXTERNAL_IDP_EXPIRY:
-        expires_at = payload.get("exp")
-        if expires_at is None:
-            return
-        try:
-            expiry_timestamp = int(expires_at)
-        except (TypeError, ValueError):
-            logger.warning("Invalid exp claim on JWT for user %s", user.email)
-            return
-
-        oidc_expiry = datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc)
-        if user.oidc_expiry == oidc_expiry:
-            return
-
-        await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
-        user.oidc_expiry = oidc_expiry  # type: ignore
-        return
-
-    if user.oidc_expiry is not None:
-        await user_manager.user_db.update(user, {"oidc_expiry": None})
-        user.oidc_expiry = None  # type: ignore
-
-
-async def _get_or_create_user_from_jwt(
-    payload: dict[str, Any],
-    request: Request,
-    async_db_session: AsyncSession,
-) -> User | None:
-    email = _extract_email_from_jwt(payload)
-    if email is None:
-        logger.warning(
-            "JWT token decoded successfully but no email claim found; skipping auth"
-        )
-        return None
-
-    # Enforce the same allowlist/domain policies as other auth flows
-    verify_email_is_invited(email)
-    verify_email_domain(email)
-
-    user_db: SQLAlchemyUserAdminDB[User, uuid.UUID] = SQLAlchemyUserAdminDB(
-        async_db_session, User, OAuthAccount
-    )
-    user_manager = UserManager(user_db)
-
-    try:
-        user = await user_manager.get_by_email(email)
-        if not user.is_active:
-            logger.warning("Inactive user %s attempted JWT login; skipping", email)
-            return None
-        if not user.role.is_web_login():
-            raise exceptions.UserNotExists()
-    except exceptions.UserNotExists:
-        logger.info("Provisioning user %s from JWT login", email)
-        try:
-            user = await user_manager.create(
-                UserCreate(
-                    email=email,
-                    password=generate_password(),
-                    is_verified=True,
-                ),
-                request=request,
-            )
-        except exceptions.UserAlreadyExists:
-            user = await user_manager.get_by_email(email)
-            if not user.is_active:
-                logger.warning(
-                    "Inactive user %s attempted JWT login during provisioning race; skipping",
-                    email,
-                )
-                return None
-            if not user.role.is_web_login():
-                logger.warning(
-                    "Non-web-login user %s attempted JWT login during provisioning race; skipping",
-                    email,
-                )
-                return None
-
-    await _sync_jwt_oidc_expiry(user_manager, user, payload)
-    return user
-
-
 async def _check_for_saml_and_jwt(
     request: Request,
     user: User | None,
@@ -1174,11 +1073,7 @@ async def _check_for_saml_and_jwt(
     auth_header = request.headers.get("Authorization")
     if auth_header and auth_header.startswith("Bearer "):
         token = auth_header[len("Bearer ") :].strip()
-        payload = await verify_jwt_token(token)
-        if payload is not None:
-            user = await _get_or_create_user_from_jwt(
-                payload, request, async_db_session
-            )
+        user = await verify_jwt_token(token, async_db_session)

     return user
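Note: verify_jwt_token now resolves the user inline with a case-insensitive email match instead of routing through the removed provisioning helpers. A standalone sketch of that query shape — the User model here is a stand-in, not the real onyx.db.models.User:

from sqlalchemy import String, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):  # stand-in model for illustration
    __tablename__ = "user"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String, unique=True)


def user_by_email_stmt(email: str):
    # Case-insensitive match, same shape as the query in verify_jwt_token above
    return select(User).where(func.lower(User.email) == func.lower(email))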
@@ -53,8 +53,8 @@ from onyx.server.query_and_chat.streaming_models import PacketObj
 from onyx.server.query_and_chat.streaming_models import ReasoningDelta
 from onyx.server.query_and_chat.streaming_models import ReasoningStart
 from onyx.server.query_and_chat.streaming_models import SectionEnd
+from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
 from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
-from onyx.tools.force import filter_tools_for_force_tool_use
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool

@@ -109,10 +109,6 @@ def _run_agent_loop(
     available_tools: Sequence[Tool] = (
         dependencies.tools if iteration_count < MAX_ITERATIONS else []
     )
-    if force_use_tool and force_use_tool.force_use:
-        available_tools = filter_tools_for_force_tool_use(
-            list(available_tools), force_use_tool
-        )
     memories = get_memories(dependencies.user_or_none, dependencies.db_session)
     # TODO: The system is rather prompt-cache efficient except for rebuilding the system prompt.
     # The biggest offender is when we hit max iterations and then all the tool calls cannot
@@ -146,8 +142,10 @@ def _run_agent_loop(
             tool_choice = None
         else:
             tool_choice = (
-                "required" if force_use_tool and force_use_tool.force_use else "auto"
-            )
+                force_use_tool_to_function_tool_names(force_use_tool, available_tools)
+                if iteration_count == 0 and force_use_tool
+                else None
+            ) or "auto"
         model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)

         agent = Agent(
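Note: the new tool_choice expression leans on Python's "or" to degrade a falsy forced-tool result (None, or an empty name list) to "auto". A sketch of just that fallback, using a hypothetical resolve_tool_choice helper rather than the real adapter:

def resolve_tool_choice(forced_names, iteration_count: int):
    # Forcing only applies on the first iteration; anything falsy falls back to "auto".
    return (forced_names if iteration_count == 0 and forced_names else None) or "auto"


assert resolve_tool_choice(["web_search"], 0) == ["web_search"]
assert resolve_tool_choice([], 0) == "auto"
assert resolve_tool_choice(["web_search"], 2) == "auto"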
@@ -89,6 +89,9 @@ STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
 HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"

+# Internet Search
+EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
+SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None

 NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
 NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
@@ -426,7 +426,7 @@ class HubSpotConnector(LoadConnector, PollConnector):

         title = ticket.properties.get("subject") or f"Ticket {ticket.id}"
         link = self._get_object_url("tickets", ticket.id)
-        content_text = ticket.properties.get("content") or ""
+        content_text = ticket.properties.get("content", "")

         # Main ticket section
         sections = [TextSection(link=link, text=content_text)]
@@ -15,14 +15,10 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from onyx.connectors.models import IndexingDocument
 from onyx.connectors.models import TextSection
 from onyx.context.search.federated.models import SlackMessage
-from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
 from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
 from onyx.context.search.federated.slack_search_utils import build_slack_queries
 from onyx.context.search.federated.slack_search_utils import ChannelTypeString
 from onyx.context.search.federated.slack_search_utils import get_channel_type
-from onyx.context.search.federated.slack_search_utils import (
-    get_channel_type_for_missing_scope,
-)
 from onyx.context.search.federated.slack_search_utils import is_recency_query
 from onyx.context.search.federated.slack_search_utils import should_include_message
 from onyx.context.search.models import InferenceChunk
@@ -50,6 +46,7 @@ logger = setup_logger()
 HIGHLIGHT_START_CHAR = "\ue000"
 HIGHLIGHT_END_CHAR = "\ue001"

+CHANNEL_TYPES = ["public_channel", "im", "mpim", "private_channel"]
 CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24  # 24 hours
 SLACK_THREAD_CONTEXT_WINDOW = 3  # Number of messages before matched message to include
 CHANNEL_METADATA_MAX_RETRIES = 3  # Maximum retry attempts for channel metadata fetching
@@ -101,12 +98,10 @@ def fetch_and_cache_channel_metadata(

     # Retry logic with exponential backoff
     last_exception = None
-    available_channel_types = ALL_CHANNEL_TYPES.copy()

     for attempt in range(CHANNEL_METADATA_MAX_RETRIES):
         try:
-            # Use available channel types (may be reduced if scopes are missing)
-            channel_types = ",".join(available_channel_types)
+            # ALWAYS fetch all channel types including private
+            channel_types = ",".join(CHANNEL_TYPES)

             # Fetch all channels in one call
             cursor = None
@@ -162,42 +157,6 @@ def fetch_and_cache_channel_metadata(

         except SlackApiError as e:
             last_exception = e

-            # Extract all needed fields from response upfront
-            if e.response:
-                error_response = e.response.get("error", "")
-                needed_scope = e.response.get("needed", "")
-            else:
-                error_response = ""
-                needed_scope = ""
-
-            # Check if this is a missing_scope error
-            if error_response == "missing_scope":
-
-                # Get the channel type that requires this scope
-                missing_channel_type = get_channel_type_for_missing_scope(needed_scope)
-
-                if (
-                    missing_channel_type
-                    and missing_channel_type in available_channel_types
-                ):
-                    # Remove the problematic channel type and retry
-                    available_channel_types.remove(missing_channel_type)
-                    logger.warning(
-                        f"Missing scope '{needed_scope}' for channel type '{missing_channel_type}'. "
-                        f"Continuing with reduced channel types: {available_channel_types}"
-                    )
-                    # Don't count this as a retry attempt, just try again with fewer types
-                    if available_channel_types:  # Only continue if we have types left
-                        continue
-                    # Otherwise fall through to retry logic
-                else:
-                    logger.error(
-                        f"Missing scope '{needed_scope}' but could not map to channel type or already removed. "
-                        f"Response: {e.response}"
-                    )
-
             # For other errors, use retry logic
             if attempt < CHANNEL_METADATA_MAX_RETRIES - 1:
                 retry_delay = CHANNEL_METADATA_RETRY_DELAY * (2**attempt)
                 logger.warning(
@@ -210,15 +169,7 @@ def fetch_and_cache_channel_metadata(
                 f"Failed to fetch channel metadata after {CHANNEL_METADATA_MAX_RETRIES} attempts: {e}"
             )

-    # If we have some channel metadata despite errors, return it with a warning
-    if channel_metadata:
-        logger.warning(
-            f"Returning partial channel metadata ({len(channel_metadata)} channels) despite errors. "
-            f"Last error: {last_exception}"
-        )
-        return channel_metadata
-
-    # If we exhausted all retries and have no data, raise the last exception
+    # If we exhausted all retries, raise the last exception
     if last_exception:
         raise SlackApiError(
             f"Channel metadata fetching failed after {CHANNEL_METADATA_MAX_RETRIES} attempts",
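Note: retries above back off exponentially — the delay doubles on each attempt — and only the final failure raises. A sketch of the schedule, assuming a base delay of 1 second (the real CHANNEL_METADATA_RETRY_DELAY constant lives in slack.py):

CHANNEL_METADATA_MAX_RETRIES = 3
CHANNEL_METADATA_RETRY_DELAY = 1  # assumed base in seconds


def backoff_delays() -> list[int]:
    # Sleeps happen on attempts 0..N-2; the last failure raises instead of sleeping.
    return [
        CHANNEL_METADATA_RETRY_DELAY * (2**attempt)
        for attempt in range(CHANNEL_METADATA_MAX_RETRIES - 1)
    ]


assert backoff_delays() == [1, 2]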
@@ -29,9 +29,6 @@ DAYS_PER_WEEK = 7
 DAYS_PER_MONTH = 30
 MAX_CONTENT_WORDS = 3

-# Punctuation to strip from words during analysis
-WORD_PUNCTUATION = ".,!?;:\"'#"
-
 RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]


@@ -44,48 +41,6 @@ class ChannelTypeString(str, Enum):
     PUBLIC_CHANNEL = "public_channel"


-# All Slack channel types for fetching metadata
-ALL_CHANNEL_TYPES = [
-    ChannelTypeString.PUBLIC_CHANNEL.value,
-    ChannelTypeString.IM.value,
-    ChannelTypeString.MPIM.value,
-    ChannelTypeString.PRIVATE_CHANNEL.value,
-]
-
-# Map Slack API scopes to their corresponding channel types
-# This is used for graceful degradation when scopes are missing
-SCOPE_TO_CHANNEL_TYPE_MAP = {
-    "mpim:read": ChannelTypeString.MPIM.value,
-    "mpim:history": ChannelTypeString.MPIM.value,
-    "im:read": ChannelTypeString.IM.value,
-    "im:history": ChannelTypeString.IM.value,
-    "groups:read": ChannelTypeString.PRIVATE_CHANNEL.value,
-    "groups:history": ChannelTypeString.PRIVATE_CHANNEL.value,
-    "channels:read": ChannelTypeString.PUBLIC_CHANNEL.value,
-    "channels:history": ChannelTypeString.PUBLIC_CHANNEL.value,
-}
-
-
-def get_channel_type_for_missing_scope(scope: str) -> str | None:
-    """Get the channel type that requires a specific Slack scope.
-
-    Args:
-        scope: The Slack API scope (e.g., 'mpim:read', 'im:history')
-
-    Returns:
-        The channel type string if scope is recognized, None otherwise
-
-    Examples:
-        >>> get_channel_type_for_missing_scope('mpim:read')
-        'mpim'
-        >>> get_channel_type_for_missing_scope('im:read')
-        'im'
-        >>> get_channel_type_for_missing_scope('unknown:scope')
-        None
-    """
-    return SCOPE_TO_CHANNEL_TYPE_MAP.get(scope)
-
-
 def _parse_llm_code_block_response(response: str) -> str:
     """Remove code block markers from LLM response if present.

@@ -109,40 +64,11 @@ def _parse_llm_code_block_response(response: str) -> str:


 def is_recency_query(query: str) -> bool:
-    """Check if a query is primarily about recency (not content + recency).
-
-    Returns True only for pure recency queries like "recent messages" or "latest updates",
-    but False for queries with content + recency like "golf scores last saturday".
-    """
-    # Check if query contains recency keywords
-    has_recency_keyword = any(
+    return any(
         re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE)
         for keyword in RECENCY_KEYWORDS
     )
-
-    if not has_recency_keyword:
-        return False
-
-    # Get combined stop words (NLTK + Slack-specific)
-    all_stop_words = _get_combined_stop_words()
-
-    # Extract content words (excluding stop words)
-    query_lower = query.lower()
-    words = query_lower.split()
-
-    # Count content words (not stop words, length > 2)
-    content_word_count = 0
-    for word in words:
-        clean_word = word.strip(WORD_PUNCTUATION)
-        if clean_word and len(clean_word) > 2 and clean_word not in all_stop_words:
-            content_word_count += 1
-
-    # If query has significant content words (>= 2), it's not a pure recency query
-    # Examples:
-    # - "recent messages" -> content_word_count = 0 -> pure recency
-    # - "golf scores last saturday" -> content_word_count = 3 (golf, scores, saturday) -> not pure recency
-    return content_word_count < 2


 def extract_date_range_from_query(
     query: str,
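Note: is_recency_query is reduced to a bare keyword test here; the content-word heuristic is gone, so mixed queries like "golf scores last saturday" now count as recency queries. A runnable restatement of the surviving logic:

import re

RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]


def is_recency_query(query: str) -> bool:
    # \b word boundaries keep "last" from matching inside words like "blast"
    return any(
        re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE)
        for keyword in RECENCY_KEYWORDS
    )


assert is_recency_query("latest updates")
assert is_recency_query("golf scores last saturday")  # True under the new logic
assert not is_recency_query("blast radius report")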
@@ -157,21 +83,6 @@ def extract_date_range_from_query(
     if re.search(r"\byesterday\b", query_lower):
         return min(1, default_search_days)

-    # Handle "last [day of week]" - e.g., "last monday", "last saturday"
-    days_of_week = [
-        "monday",
-        "tuesday",
-        "wednesday",
-        "thursday",
-        "friday",
-        "saturday",
-        "sunday",
-    ]
-    for day in days_of_week:
-        if re.search(rf"\b(?:last|this)\s+{day}\b", query_lower):
-            # Assume last occurrence of that day was within the past week
-            return min(DAYS_PER_WEEK, default_search_days)
-
     match = re.search(r"\b(?:last|past)\s+(\d+)\s+days?\b", query_lower)
     if match:
         days = int(match.group(1))
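Note: with the day-of-week block removed, only explicit "last/past N days" phrasing still yields a window at this point in the function. The surviving pattern, isolated:

import re


def days_from_query(query: str) -> int | None:
    match = re.search(r"\b(?:last|past)\s+(\d+)\s+days?\b", query.lower())
    return int(match.group(1)) if match else None


assert days_from_query("errors from the past 14 days") == 14
assert days_from_query("messages from last tuesday") is None  # day-of-week handling removed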
@@ -209,40 +120,22 @@ def extract_date_range_from_query(

         try:
             data = json.loads(response_clean)
-            if not isinstance(data, dict):
-                logger.debug(
-                    f"LLM date extraction returned non-dict response for query: "
-                    f"'{query}', using default: {default_search_days} days"
-                )
-                return default_search_days
-
             days_back = data.get("days_back")
             if days_back is None:
                 logger.debug(
-                    f"LLM date extraction returned null for query: '{query}', "
-                    f"using default: {default_search_days} days"
+                    f"LLM date extraction returned null for query: '{query}', using default: {default_search_days} days"
                 )
                 return default_search_days
-
-            if not isinstance(days_back, (int, float)):
-                logger.debug(
-                    f"LLM date extraction returned non-numeric days_back for "
-                    f"query: '{query}', using default: {default_search_days} days"
-                )
-                return default_search_days
-
         except json.JSONDecodeError:
             logger.debug(
-                f"Failed to parse LLM date extraction response for query: '{query}' "
-                f"(response: '{response_clean}'), "
-                f"using default: {default_search_days} days"
+                f"Failed to parse LLM date extraction response for query: '{query}', using default: {default_search_days} days"
            )
            return default_search_days

-        return min(int(days_back), default_search_days)
+        return min(days_back, default_search_days)

     except Exception as e:
-        logger.warning(f"Error extracting date range with LLM for query '{query}': {e}")
+        logger.warning(f"Error extracting date range with LLM: {e}")
         return default_search_days
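Note: the slimmed-down parser keeps only the JSONDecodeError guard; the dict and numeric type checks are gone, so malformed-but-valid JSON now escapes to the outer except. A standalone restatement of the new shape:

import json


def parse_days_back(response_clean: str, default: int) -> int:
    try:
        # A non-dict JSON payload (e.g. a list) now raises AttributeError here,
        # caught only by the outer handler in the real function.
        days_back = json.loads(response_clean).get("days_back")
        if days_back is None:
            return default
    except json.JSONDecodeError:
        return default
    # A non-numeric days_back would raise TypeError in min(); also uncaught here.
    return min(days_back, default)


assert parse_days_back('{"days_back": 3}', 30) == 3
assert parse_days_back("not json", 30) == 30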
@@ -520,29 +413,6 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
 )


-def _get_combined_stop_words() -> set[str]:
-    """Get combined NLTK + Slack-specific stop words.
-
-    Returns a set of stop words for filtering content words.
-    Falls back to just Slack-specific stop words if NLTK is unavailable.
-
-    Note: Currently only supports English stop words. Non-English queries
-    may have suboptimal content word extraction. Future enhancement could
-    detect query language and load appropriate stop words.
-    """
-    try:
-        from nltk.corpus import stopwords  # type: ignore
-
-        # TODO: Support multiple languages - currently hardcoded to English
-        # Could detect language or allow configuration
-        nltk_stop_words = set(stopwords.words("english"))
-    except Exception:
-        # Fallback if NLTK not available
-        nltk_stop_words = set()
-
-    return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
-
-
 def extract_content_words_from_recency_query(
     query_text: str, channel_references: set[str]
 ) -> list[str]:
@@ -557,19 +427,28 @@ def extract_content_words_from_recency_query(
     Returns:
         List of content words (up to MAX_CONTENT_WORDS)
     """
-    # Get combined stop words (NLTK + Slack-specific)
-    all_stop_words = _get_combined_stop_words()
+    # Get standard English stop words from NLTK (lazy import)
+    try:
+        from nltk.corpus import stopwords  # type: ignore
+
+        nltk_stop_words = set(stopwords.words("english"))
+    except Exception:
+        # Fallback if NLTK not available
+        nltk_stop_words = set()
+
+    # Combine NLTK stop words with Slack-specific stop words
+    all_stop_words = nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS

     words = query_text.split()
     content_words = []

     for word in words:
-        clean_word = word.lower().strip(WORD_PUNCTUATION)
+        clean_word = word.lower().strip(".,!?;:\"'#")
         # Skip if it's a channel reference or a stop word
         if clean_word in channel_references:
             continue
         if clean_word and clean_word not in all_stop_words and len(clean_word) > 2:
-            clean_word_orig = word.strip(WORD_PUNCTUATION)
+            clean_word_orig = word.strip(".,!?;:\"'#")
             if clean_word_orig.lower() not in all_stop_words:
                 content_words.append(clean_word_orig)

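Note: stop-word construction is inlined back into this function and the WORD_PUNCTUATION constant is replaced by its literal. The filtering idiom itself, isolated with a stand-in stop-word set:

STOP_WORDS = {"the", "from", "show", "me"}  # stand-in; the real set is NLTK + Slack-specific


def content_words(query_text: str) -> list[str]:
    out = []
    for word in query_text.split():
        # strip() removes leading/trailing punctuation only, not interior characters
        clean = word.lower().strip(".,!?;:\"'#")
        if clean and len(clean) > 2 and clean not in STOP_WORDS:
            out.append(word.strip(".,!?;:\"'#"))  # keep the original casing
    return out


assert content_words("Show me the #golf scores!") == ["golf", "scores"]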
@@ -442,25 +442,10 @@ def set_cc_pair_repeated_error_state(
     cc_pair_id: int,
     in_repeated_error_state: bool,
 ) -> None:
-    values: dict = {"in_repeated_error_state": in_repeated_error_state}
-
-    # When entering repeated error state, also pause the connector
-    # to prevent continued indexing retry attempts.
-    # However, don't pause if there's an active manual indexing trigger,
-    # which indicates the user wants to retry immediately.
-    if in_repeated_error_state:
-        cc_pair = get_connector_credential_pair_from_id(
-            db_session=db_session,
-            cc_pair_id=cc_pair_id,
-        )
-        # Only pause if there's no manual indexing trigger active
-        if cc_pair and cc_pair.indexing_trigger is None:
-            values["status"] = ConnectorCredentialPairStatus.PAUSED
-
     stmt = (
         update(ConnectorCredentialPair)
         .where(ConnectorCredentialPair.id == cc_pair_id)
-        .values(**values)
+        .values(in_repeated_error_state=in_repeated_error_state)
     )
     db_session.execute(stmt)
     db_session.commit()
@@ -2482,50 +2482,6 @@ class CloudEmbeddingProvider(Base):
         return f"<EmbeddingProvider(type='{self.provider_type}')>"


-class InternetSearchProvider(Base):
-    __tablename__ = "internet_search_provider"
-
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
-    provider_type: Mapped[str] = mapped_column(String, nullable=False)
-    api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
-    config: Mapped[dict[str, str] | None] = mapped_column(
-        postgresql.JSONB(), nullable=True
-    )
-    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-    time_created: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now()
-    )
-    time_updated: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
-    )
-
-    def __repr__(self) -> str:
-        return f"<InternetSearchProvider(name='{self.name}', provider_type='{self.provider_type}')>"
-
-
-class InternetContentProvider(Base):
-    __tablename__ = "internet_content_provider"
-
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
-    provider_type: Mapped[str] = mapped_column(String, nullable=False)
-    api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
-    config: Mapped[dict[str, str] | None] = mapped_column(
-        postgresql.JSONB(), nullable=True
-    )
-    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-    time_created: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now()
-    )
-    time_updated: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
-    )
-
-    def __repr__(self) -> str:
-        return f"<InternetContentProvider(name='{self.name}', provider_type='{self.provider_type}')>"
-
-
 class DocumentSet(Base):
     __tablename__ = "document_set"

@@ -1,309 +0,0 @@
-from __future__ import annotations
-
-from sqlalchemy import select
-from sqlalchemy import update
-from sqlalchemy.orm import Session
-
-from onyx.db.models import InternetContentProvider
-from onyx.db.models import InternetSearchProvider
-from shared_configs.enums import WebContentProviderType
-from shared_configs.enums import WebSearchProviderType
-
-
-def fetch_web_search_providers(db_session: Session) -> list[InternetSearchProvider]:
-    stmt = select(InternetSearchProvider).order_by(InternetSearchProvider.id.asc())
-    return list(db_session.scalars(stmt).all())
-
-
-def fetch_web_content_providers(db_session: Session) -> list[InternetContentProvider]:
-    stmt = select(InternetContentProvider).order_by(InternetContentProvider.id.asc())
-    return list(db_session.scalars(stmt).all())
-
-
-def fetch_active_web_search_provider(
-    db_session: Session,
-) -> InternetSearchProvider | None:
-    stmt = select(InternetSearchProvider).where(
-        InternetSearchProvider.is_active.is_(True)
-    )
-    return db_session.scalars(stmt).first()
-
-
-def fetch_web_search_provider_by_id(
-    provider_id: int, db_session: Session
-) -> InternetSearchProvider | None:
-    return db_session.get(InternetSearchProvider, provider_id)
-
-
-def fetch_web_search_provider_by_name(
-    name: str, db_session: Session
-) -> InternetSearchProvider | None:
-    stmt = select(InternetSearchProvider).where(InternetSearchProvider.name.ilike(name))
-    return db_session.scalars(stmt).first()
-
-
-def _ensure_unique_search_name(
-    name: str, provider_id: int | None, db_session: Session
-) -> None:
-    existing = fetch_web_search_provider_by_name(name=name, db_session=db_session)
-    if existing and existing.id != provider_id:
-        raise ValueError(f"A web search provider named '{name}' already exists.")
-
-
-def _apply_search_provider_updates(
-    provider: InternetSearchProvider,
-    *,
-    name: str,
-    provider_type: WebSearchProviderType,
-    api_key: str | None,
-    api_key_changed: bool,
-    config: dict[str, str] | None,
-) -> None:
-    provider.name = name
-    provider.provider_type = provider_type.value
-    provider.config = config
-    if api_key_changed or provider.api_key is None:
-        provider.api_key = api_key
-
-
-def upsert_web_search_provider(
-    *,
-    provider_id: int | None,
-    name: str,
-    provider_type: WebSearchProviderType,
-    api_key: str | None,
-    api_key_changed: bool,
-    config: dict[str, str] | None,
-    activate: bool,
-    db_session: Session,
-) -> InternetSearchProvider:
-    _ensure_unique_search_name(
-        name=name, provider_id=provider_id, db_session=db_session
-    )
-
-    provider: InternetSearchProvider | None = None
-    if provider_id is not None:
-        provider = fetch_web_search_provider_by_id(provider_id, db_session)
-        if provider is None:
-            raise ValueError(f"No web search provider with id {provider_id} exists.")
-    else:
-        provider = InternetSearchProvider()
-        db_session.add(provider)
-
-    _apply_search_provider_updates(
-        provider,
-        name=name,
-        provider_type=provider_type,
-        api_key=api_key,
-        api_key_changed=api_key_changed,
-        config=config,
-    )
-
-    db_session.flush()
-
-    if activate:
-        set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
-
-    db_session.commit()
-    db_session.refresh(provider)
-    return provider
-
-
-def set_active_web_search_provider(
-    *, provider_id: int | None, db_session: Session
-) -> InternetSearchProvider:
-    if provider_id is None:
-        raise ValueError("Cannot activate a provider without an id.")
-
-    provider = fetch_web_search_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web search provider with id {provider_id} exists.")
-
-    db_session.execute(
-        update(InternetSearchProvider)
-        .where(
-            InternetSearchProvider.is_active.is_(True),
-            InternetSearchProvider.id != provider_id,
-        )
-        .values(is_active=False)
-    )
-    provider.is_active = True
-
-    db_session.flush()
-    db_session.refresh(provider)
-    return provider
-
-
-def deactivate_web_search_provider(
-    *, provider_id: int | None, db_session: Session
-) -> InternetSearchProvider:
-    if provider_id is None:
-        raise ValueError("Cannot deactivate a provider without an id.")
-
-    provider = fetch_web_search_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web search provider with id {provider_id} exists.")
-
-    provider.is_active = False
-
-    db_session.flush()
-    db_session.refresh(provider)
-    return provider
-
-
-def delete_web_search_provider(provider_id: int, db_session: Session) -> None:
-    provider = fetch_web_search_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web search provider with id {provider_id} exists.")
-
-    db_session.delete(provider)
-    db_session.flush()
-
-    db_session.commit()
-
-
-# Content provider helpers
-
-
-def fetch_active_web_content_provider(
-    db_session: Session,
-) -> InternetContentProvider | None:
-    stmt = select(InternetContentProvider).where(
-        InternetContentProvider.is_active.is_(True)
-    )
-    return db_session.scalars(stmt).first()
-
-
-def fetch_web_content_provider_by_id(
-    provider_id: int, db_session: Session
-) -> InternetContentProvider | None:
-    return db_session.get(InternetContentProvider, provider_id)
-
-
-def fetch_web_content_provider_by_name(
-    name: str, db_session: Session
-) -> InternetContentProvider | None:
-    stmt = select(InternetContentProvider).where(
-        InternetContentProvider.name.ilike(name)
-    )
-    return db_session.scalars(stmt).first()
-
-
-def _ensure_unique_content_name(
-    name: str, provider_id: int | None, db_session: Session
-) -> None:
-    existing = fetch_web_content_provider_by_name(name=name, db_session=db_session)
-    if existing and existing.id != provider_id:
-        raise ValueError(f"A web content provider named '{name}' already exists.")
-
-
-def _apply_content_provider_updates(
-    provider: InternetContentProvider,
-    *,
-    name: str,
-    provider_type: WebContentProviderType,
-    api_key: str | None,
-    api_key_changed: bool,
-    config: dict[str, str] | None,
-) -> None:
-    provider.name = name
-    provider.provider_type = provider_type.value
-    provider.config = config
-    if api_key_changed or provider.api_key is None:
-        provider.api_key = api_key
-
-
-def upsert_web_content_provider(
-    *,
-    provider_id: int | None,
-    name: str,
-    provider_type: WebContentProviderType,
-    api_key: str | None,
-    api_key_changed: bool,
-    config: dict[str, str] | None,
-    activate: bool,
-    db_session: Session,
-) -> InternetContentProvider:
-    _ensure_unique_content_name(
-        name=name, provider_id=provider_id, db_session=db_session
-    )
-
-    provider: InternetContentProvider | None = None
-    if provider_id is not None:
-        provider = fetch_web_content_provider_by_id(provider_id, db_session)
-        if provider is None:
-            raise ValueError(f"No web content provider with id {provider_id} exists.")
-    else:
-        provider = InternetContentProvider()
-        db_session.add(provider)
-
-    _apply_content_provider_updates(
-        provider,
-        name=name,
-        provider_type=provider_type,
-        api_key=api_key,
-        api_key_changed=api_key_changed,
-        config=config,
-    )
-
-    db_session.flush()
-
-    if activate:
-        set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
-
-    db_session.commit()
-    db_session.refresh(provider)
-    return provider
-
-
-def set_active_web_content_provider(
-    *, provider_id: int | None, db_session: Session
-) -> InternetContentProvider:
-    if provider_id is None:
-        raise ValueError("Cannot activate a provider without an id.")
-
-    provider = fetch_web_content_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web content provider with id {provider_id} exists.")
-
-    db_session.execute(
-        update(InternetContentProvider)
-        .where(
-            InternetContentProvider.is_active.is_(True),
-            InternetContentProvider.id != provider_id,
-        )
-        .values(is_active=False)
-    )
-    provider.is_active = True
-
-    db_session.flush()
-    db_session.refresh(provider)
-    return provider
-
-
-def deactivate_web_content_provider(
-    *, provider_id: int | None, db_session: Session
-) -> InternetContentProvider:
-    if provider_id is None:
-        raise ValueError("Cannot deactivate a provider without an id.")
-
-    provider = fetch_web_content_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web content provider with id {provider_id} exists.")
-
-    provider.is_active = False
-
-    db_session.flush()
-    db_session.refresh(provider)
-    return provider
-
-
-def delete_web_content_provider(provider_id: int, db_session: Session) -> None:
-    provider = fetch_web_content_provider_by_id(provider_id, db_session)
-    if provider is None:
-        raise ValueError(f"No web content provider with id {provider_id} exists.")
-
-    db_session.delete(provider)
-    db_session.flush()
-
-    db_session.commit()
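Note: the deleted set_active_web_search_provider / set_active_web_content_provider both enforced a single active row by deactivating every other active row before flipping the target on. The invariant, condensed into a generic sketch (`model` stands in for either table and the row is assumed to exist):

from sqlalchemy import update


def activate_only(db_session, model, provider_id: int) -> None:
    # Turn off all other active rows in the same transaction, then activate one.
    db_session.execute(
        update(model)
        .where(model.is_active.is_(True), model.id != provider_id)
        .values(is_active=False)
    )
    db_session.get(model, provider_id).is_active = True
    db_session.flush()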
@@ -510,7 +510,6 @@ class LitellmLLM(LLM):
             # model params
             temperature=(1 if is_reasoning else self._temperature),
             timeout=timeout_override or self._timeout,
-            **({"stream_options": {"include_usage": True}} if stream else {}),
             # For now, we don't support parallel tool calls
             # NOTE: we can't pass this in if tools are not specified
             # or else OpenAI throws an error

@@ -3,87 +3,29 @@ import time
 import uuid
 from typing import Any
 from typing import cast
-from typing import Dict
-from typing import List
 from typing import Optional
-from typing import Tuple
-from typing import TypedDict
 from typing import Union

 from litellm import AllMessageValues
 from litellm.completion_extras.litellm_responses_transformation.transformation import (
     OpenAiResponsesToChatCompletionStreamIterator,
 )
 from litellm.litellm_core_utils.prompt_templates.common_utils import (
     convert_content_list_to_str,
 )

-try:
-    from litellm.litellm_core_utils.prompt_templates.common_utils import (
-        extract_images_from_message,
-    )
-except ImportError:
-    extract_images_from_message = None  # type: ignore[assignment]
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    extract_images_from_message,
+)
 from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
 from litellm.llms.ollama.chat.transformation import OllamaChatConfig
 from litellm.llms.ollama.common_utils import OllamaError

-try:
-    from litellm.types.llms.ollama import OllamaChatCompletionMessage
-except ImportError:
-
-    class OllamaChatCompletionMessage(TypedDict, total=False):  # type: ignore[no-redef]
-        """Fallback for LiteLLM versions where this TypedDict was removed."""
-
-        role: str
-        content: Optional[str]
-        images: Optional[List[Any]]
-        thinking: Optional[str]
-        tool_calls: Optional[List["OllamaToolCall"]]
-
-
+from litellm.types.llms.ollama import OllamaChatCompletionMessage
 from litellm.types.llms.ollama import OllamaToolCall
 from litellm.types.llms.ollama import OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 from litellm.types.utils import ChatCompletionUsageBlock
 from litellm.types.utils import GenericStreamingChunk
 from litellm.types.utils import ModelResponseStream
 from litellm.utils import verbose_logger
 from pydantic import BaseModel


-if extract_images_from_message is None:
-
-    def extract_images_from_message(
-        message: AllMessageValues,
-    ) -> Optional[List[Any]]:
-        """Fallback for LiteLLM versions that dropped extract_images_from_message."""
-
-        images: List[Any] = []
-        content = message.get("content")
-        if not isinstance(content, list):
-            return None
-
-        for item in content:
-            if not isinstance(item, Dict):
-                continue
-
-            item_type = item.get("type")
-            if item_type == "image_url":
-                image_url = item.get("image_url")
-                if isinstance(image_url, dict):
-                    if image_url.get("url"):
-                        images.append(image_url)
-                elif image_url:
-                    images.append(image_url)
-            elif item_type in {"input_image", "image"}:
-                image_value = item.get("image")
-                if image_value:
-                    images.append(image_value)
-
-        return images or None
-
-
 def _patch_ollama_transform_request() -> None:
     """
     Patches OllamaChatConfig.transform_request to handle reasoning content
@@ -312,189 +254,16 @@ def _patch_ollama_chunk_parser() -> None:
    OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser  # type: ignore[method-assign]


def _patch_openai_responses_chunk_parser() -> None:
    """
    Patches OpenAiResponsesToChatCompletionStreamIterator.chunk_parser to properly
    handle OpenAI Responses API streaming format and convert it to chat completion format.
    """
    if (
        getattr(
            OpenAiResponsesToChatCompletionStreamIterator.chunk_parser,
            "__name__",
            "",
        )
        == "_patched_openai_responses_chunk_parser"
    ):
        return

    def _patched_openai_responses_chunk_parser(
        self: Any, chunk: dict
    ) -> Union["GenericStreamingChunk", "ModelResponseStream"]:
        # Transform responses API streaming chunk to chat completion format
        from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
        from litellm.types.utils import (
            ChatCompletionToolCallChunk,
            GenericStreamingChunk,
        )

        verbose_logger.debug(
            f"Chat provider: transform_streaming_response called with chunk: {chunk}"
        )
        parsed_chunk = chunk
        if not parsed_chunk:
            raise ValueError("Chat provider: Empty parsed_chunk")
        if not isinstance(parsed_chunk, dict):
            raise ValueError(f"Chat provider: Invalid chunk type {type(parsed_chunk)}")
        # Handle different event types from responses API

        event_type = parsed_chunk.get("type")
        verbose_logger.debug(f"Chat provider: Processing event type: {event_type}")

        if event_type == "response.created":
            # Initial response creation event
            verbose_logger.debug(f"Chat provider: response.created -> {chunk}")
            return GenericStreamingChunk(
                text="", tool_use=None, is_finished=False, finish_reason="", usage=None
            )

        elif event_type == "response.output_item.added":
            # New output item added
            output_item = parsed_chunk.get("item", {})
            if output_item.get("type") == "function_call":
                return GenericStreamingChunk(
                    text="",
                    tool_use=ChatCompletionToolCallChunk(
                        id=output_item.get("call_id"),
                        index=0,
                        type="function",
                        function=ChatCompletionToolCallFunctionChunk(
                            name=output_item.get("name", None),
                            arguments=parsed_chunk.get("arguments", ""),
                        ),
                    ),
                    is_finished=False,
                    finish_reason="",
                    usage=None,
                )
            elif output_item.get("type") == "message":
                pass
            elif output_item.get("type") == "reasoning":
                pass
            else:
                raise ValueError(f"Chat provider: Invalid output_item {output_item}")

        elif event_type == "response.function_call_arguments.delta":
            content_part: Optional[str] = parsed_chunk.get("delta", None)
            if content_part:
                return GenericStreamingChunk(
                    text="",
                    tool_use=ChatCompletionToolCallChunk(
                        id=None,
                        index=0,
                        type="function",
                        function=ChatCompletionToolCallFunctionChunk(
                            name=None, arguments=content_part
                        ),
                    ),
                    is_finished=False,
                    finish_reason="",
                    usage=None,
                )
            else:
                raise ValueError(
                    f"Chat provider: Invalid function argument delta {parsed_chunk}"
                )

        elif event_type == "response.output_item.done":
            # Output item completed
            output_item = parsed_chunk.get("item", {})
            if output_item.get("type") == "function_call":
                return GenericStreamingChunk(
                    text="",
                    tool_use=ChatCompletionToolCallChunk(
                        id=output_item.get("call_id"),
                        index=0,
                        type="function",
                        function=ChatCompletionToolCallFunctionChunk(
                            name=parsed_chunk.get("name", None),
                            arguments="",  # responses API sends everything again, we don't
                        ),
                    ),
                    is_finished=True,
                    finish_reason="tool_calls",
                    usage=None,
                )
            elif output_item.get("type") == "message":
                return GenericStreamingChunk(
                    finish_reason="stop", is_finished=True, usage=None, text=""
                )
            elif output_item.get("type") == "reasoning":
                pass
            else:
                raise ValueError(f"Chat provider: Invalid output_item {output_item}")

        elif event_type == "response.output_text.delta":
            # Content part added to output
            content_part = parsed_chunk.get("delta", None)
            if content_part is not None:
                return GenericStreamingChunk(
                    text=content_part,
                    tool_use=None,
                    is_finished=False,
                    finish_reason="",
                    usage=None,
                )
            else:
                raise ValueError(f"Chat provider: Invalid text delta {parsed_chunk}")

        elif event_type == "response.reasoning_summary_text.delta":
            content_part = parsed_chunk.get("delta", None)
            if content_part:
                from litellm.types.utils import (
                    Delta,
                    ModelResponseStream,
                    StreamingChoices,
                )

                return ModelResponseStream(
                    choices=[
                        StreamingChoices(
                            index=cast(int, parsed_chunk.get("summary_index")),
                            delta=Delta(reasoning_content=content_part),
                        )
                    ]
                )

        else:
            pass

        # For any unhandled event types, create a minimal valid chunk or skip
        verbose_logger.debug(
            f"Chat provider: Unhandled event type '{event_type}', creating empty chunk"
        )
        # Return a minimal valid chunk for unknown events
        return GenericStreamingChunk(
            text="", tool_use=None, is_finished=False, finish_reason="", usage=None
        )

    _patched_openai_responses_chunk_parser.__name__ = (
        "_patched_openai_responses_chunk_parser"
    )
    OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_openai_responses_chunk_parser  # type: ignore[method-assign]


def apply_monkey_patches() -> None:
    """
    Apply all necessary monkey patches to LiteLLM for compatibility.
    Apply all necessary monkey patches to LiteLLM for Ollama compatibility.

    This includes:
    - Patching OllamaChatConfig.transform_request for reasoning content support
    - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
    - Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
    """
    _patch_ollama_transform_request()
    _patch_ollama_chunk_parser()
    _patch_openai_responses_chunk_parser()
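A minimal usage sketch: call the entry point once at process startup, before any LiteLLM call is made (the import path is hypothetical):

from onyx.llm.litellm_monkey_patches import apply_monkey_patches  # hypothetical module path

apply_monkey_patches()  # the __name__ guard shown above prevents double-wrapping on repeat calls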

def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
146 backend/onyx/llm/message_format.py Normal file
@@ -0,0 +1,146 @@
from collections.abc import Sequence

from langchain.schema.messages import BaseMessage

from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ContentPart
from onyx.llm.message_types import ImageContentPart
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import TextContentPart
from onyx.llm.message_types import UserMessageWithParts
from onyx.llm.message_types import UserMessageWithText


def base_messages_to_chat_completion_msgs(
    msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
    return [_base_message_to_chat_completion_msg(msg) for msg in msgs]


def _base_message_to_chat_completion_msg(msg: BaseMessage) -> ChatCompletionMessage:
    message_type_to_role = {
        "human": "user",
        "system": "system",
        "ai": "assistant",
        "tool": "tool",
    }
    role = message_type_to_role[msg.type]

    content = msg.content

    if isinstance(content, str):
        # Simple string content
        if role == "system":
            system_msg: SystemMessage = {
                "role": "system",
                "content": content,
            }
            return system_msg
        elif role == "user":
            user_msg: UserMessageWithText = {
                "role": "user",
                "content": content,
            }
            return user_msg
        else:  # assistant
            assistant_msg: AssistantMessage = {
                "role": "assistant",
                "content": content,
            }
            return assistant_msg

    elif isinstance(content, list):
        # List content - need to convert to OpenAI format
        if role == "assistant":
            # For assistant, convert list to simple string
            # (OpenAI format uses string content, not list)
            text_parts = []
            for item in content:
                if isinstance(item, str):
                    text_parts.append(item)
                elif isinstance(item, dict) and item.get("type") == "text":
                    text_parts.append(item.get("text", ""))
                else:
                    raise ValueError(
                        f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
                    )

            assistant_msg_from_list: AssistantMessage = {
                "role": "assistant",
                "content": " ".join(text_parts) if text_parts else None,
            }
            return assistant_msg_from_list

        else:  # system or user
            content_parts: list[ContentPart] = []
            has_images = False

            for item in content:
                if isinstance(item, str):
                    content_parts.append(TextContentPart(type="text", text=item))
                elif isinstance(item, dict):
                    item_type = item.get("type")
                    if item_type == "text":
                        content_parts.append(
                            TextContentPart(type="text", text=item.get("text", ""))
                        )
                    elif item_type == "image_url":
                        has_images = True
                        # Convert image_url to OpenAI format
                        image_url = item.get("image_url", {})
                        if isinstance(image_url, dict):
                            url = image_url.get("url", "")
                            detail = image_url.get("detail", "auto")
                        else:
                            url = image_url
                            detail = "auto"

                        image_part: ImageContentPart = {
                            "type": "image_url",
                            "image_url": {"url": url, "detail": detail},
                        }
                        content_parts.append(image_part)
                    else:
                        raise ValueError(f"Unexpected item type: {item_type}")
                else:
                    raise ValueError(
                        f"Unexpected item type: {type(item)}. Item: {item}"
                    )

            if role == "system":
                # System messages should be text only, concatenate all text parts
                text_parts = [
                    part["text"] for part in content_parts if part["type"] == "text"
                ]
                system_msg_from_list: SystemMessage = {
                    "role": "system",
                    "content": " ".join(text_parts),
                }
                return system_msg_from_list
            else:  # user
                # If there are images, use the parts format; otherwise use simple string
                if has_images or len(content_parts) > 1:
                    user_msg_with_parts: UserMessageWithParts = {
                        "role": "user",
                        "content": content_parts,
                    }
                    return user_msg_with_parts
                elif len(content_parts) == 1 and content_parts[0]["type"] == "text":
                    # Single text part - use simple string format
                    user_msg_simple: UserMessageWithText = {
                        "role": "user",
                        "content": content_parts[0]["text"],
                    }
                    return user_msg_simple
                else:
                    # Empty content
                    user_msg_empty: UserMessageWithText = {
                        "role": "user",
                        "content": "",
                    }
                    return user_msg_empty
    else:
        raise ValueError(
            f"Unexpected content type: {type(content)}. Content: {content}"
        )
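A short usage sketch for the converter above (message contents are illustrative; the expected shapes follow the branches of _base_message_to_chat_completion_msg):

from langchain.schema.messages import AIMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage as LCSystemMessage

msgs = [
    LCSystemMessage(content="You are a helpful assistant."),
    HumanMessage(content=[{"type": "text", "text": "Hi"}]),
    AIMessage(content="Hello!"),
]
converted = base_messages_to_chat_completion_msgs(msgs)
# -> [{"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hi"},  # a single text part collapses to a plain string
#     {"role": "assistant", "content": "Hello!"}]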
@@ -38,19 +38,10 @@ class StreamingChoice(BaseModel):
    delta: Delta = Field(default_factory=Delta)


class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
    cache_creation_input_tokens: int
    cache_read_input_tokens: int


class ModelResponseStream(BaseModel):
    id: str
    created: str
    choice: StreamingChoice
    usage: Usage | None = None


if TYPE_CHECKING:
@@ -183,24 +174,10 @@ def from_litellm_model_response_stream(
        delta=parsed_delta,
    )

    usage_data = response_data.get("usage")
    return ModelResponseStream(
        id=response_id,
        created=created,
        choice=streaming_choice,
        usage=(
            Usage(
                completion_tokens=usage_data.get("completion_tokens", 0),
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
                cache_creation_input_tokens=usage_data.get(
                    "cache_creation_input_tokens", 0
                ),
                cache_read_input_tokens=usage_data.get("cache_read_input_tokens", 0),
            )
            if usage_data
            else None
        ),
    )

@@ -116,7 +116,6 @@ def litellm_exception_to_error_msg(
    from litellm.exceptions import Timeout
    from litellm.exceptions import ContentPolicyViolationError
    from litellm.exceptions import BudgetExceededError
    from litellm.exceptions import ServiceUnavailableError

    core_exception = _unwrap_nested_exception(e)
    error_msg = str(core_exception)
@@ -169,23 +168,6 @@ def litellm_exception_to_error_msg(
        if upstream_detail
        else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
    )
    elif isinstance(core_exception, ServiceUnavailableError):
        provider_name = (
            llm.config.model_provider
            if llm is not None and llm.config.model_provider
            else "The LLM provider"
        )
        # Check if this is specifically the Bedrock "Too many connections" error
        if "Too many connections" in error_msg or "BedrockException" in error_msg:
            error_msg = (
                f"{provider_name} is experiencing high connection volume and cannot process your request right now. "
                "This typically happens when there are too many simultaneous requests to the AI model. "
                "Please wait a moment and try again. If this persists, contact your system administrator "
                "to review connection limits and retry configurations."
            )
        else:
            # Generic 503 Service Unavailable
            error_msg = f"{provider_name} service error: {str(core_exception)}"
    elif isinstance(core_exception, ContextWindowExceededError):
        error_msg = (
            "Context window exceeded: Your input is too long for the model to process."
@@ -1,7 +1,6 @@
import logging
import sys
import traceback
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
@@ -103,9 +102,6 @@ from onyx.server.manage.llm.api import basic_router as llm_router
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
    admin_router as web_search_admin_router,
)
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -147,13 +143,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

warnings.filterwarnings(
    "ignore", category=ResourceWarning, message=r"Unclosed client session"
)
warnings.filterwarnings(
    "ignore", category=ResourceWarning, message=r"Unclosed connector"
)

logger = setup_logger()

file_handlers = [
@@ -403,7 +392,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
    include_router_with_global_prefix_prepended(application, llm_router)
    include_router_with_global_prefix_prepended(application, embedding_admin_router)
    include_router_with_global_prefix_prepended(application, embedding_router)
    include_router_with_global_prefix_prepended(application, web_search_admin_router)
    include_router_with_global_prefix_prepended(
        application, token_rate_limit_settings_router
    )

@@ -125,18 +125,18 @@ def get_versions() -> AllVersions:
        onyx=latest_stable_version,
        relational_db="postgres:15.2-alpine",
        index="vespaengine/vespa:8.277.17",
        nginx="nginx:1.25.5-alpine",
        nginx="nginx:1.23.4-alpine",
    ),
    dev=ContainerVersions(
        onyx=latest_dev_version,
        relational_db="postgres:15.2-alpine",
        index="vespaengine/vespa:8.277.17",
        nginx="nginx:1.25.5-alpine",
        nginx="nginx:1.23.4-alpine",
    ),
    migration=ContainerVersions(
        onyx="airgapped-intfloat-nomic-migration",
        relational_db="postgres:15.2-alpine",
        index="vespaengine/vespa:8.277.17",
        nginx="nginx:1.25.5-alpine",
        nginx="nginx:1.23.4-alpine",
    ),
)

@@ -1,364 +0,0 @@
from __future__ import annotations

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    build_content_provider_from_config,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    build_search_provider_from_config,
)
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.web_search import deactivate_web_content_provider
from onyx.db.web_search import deactivate_web_search_provider
from onyx.db.web_search import delete_web_content_provider
from onyx.db.web_search import delete_web_search_provider
from onyx.db.web_search import fetch_web_content_provider_by_name
from onyx.db.web_search import fetch_web_content_providers
from onyx.db.web_search import fetch_web_search_provider_by_name
from onyx.db.web_search import fetch_web_search_providers
from onyx.db.web_search import set_active_web_content_provider
from onyx.db.web_search import set_active_web_search_provider
from onyx.db.web_search import upsert_web_content_provider
from onyx.db.web_search import upsert_web_search_provider
from onyx.server.manage.web_search.models import WebContentProviderTestRequest
from onyx.server.manage.web_search.models import WebContentProviderUpsertRequest
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderTestRequest
from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType

logger = setup_logger()

admin_router = APIRouter(prefix="/admin/web-search")


@admin_router.get("/search-providers", response_model=list[WebSearchProviderView])
def list_search_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[WebSearchProviderView]:
    providers = fetch_web_search_providers(db_session)
    return [
        WebSearchProviderView(
            id=provider.id,
            name=provider.name,
            provider_type=WebSearchProviderType(provider.provider_type),
            is_active=provider.is_active,
            config=provider.config or {},
            has_api_key=bool(provider.api_key),
        )
        for provider in providers
    ]


@admin_router.post("/search-providers", response_model=WebSearchProviderView)
def upsert_search_provider_endpoint(
    request: WebSearchProviderUpsertRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
    existing_by_name = fetch_web_search_provider_by_name(request.name, db_session)
    if (
        existing_by_name
        and request.id is not None
        and existing_by_name.id != request.id
    ):
        raise HTTPException(
            status_code=400,
            detail=f"A search provider named '{request.name}' already exists.",
        )

    provider = upsert_web_search_provider(
        provider_id=request.id,
        name=request.name,
        provider_type=request.provider_type,
        api_key=request.api_key,
        api_key_changed=request.api_key_changed,
        config=request.config,
        activate=request.activate,
        db_session=db_session,
    )

    db_session.commit()
    return WebSearchProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=WebSearchProviderType(provider.provider_type),
        is_active=provider.is_active,
        config=provider.config or {},
        has_api_key=bool(provider.api_key),
    )


@admin_router.delete(
    "/search-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_search_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    delete_web_search_provider(provider_id, db_session)
    return Response(status_code=204)


@admin_router.post("/search-providers/{provider_id}/activate")
def activate_search_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
    provider = set_active_web_search_provider(
        provider_id=provider_id, db_session=db_session
    )
    db_session.commit()
    return WebSearchProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=WebSearchProviderType(provider.provider_type),
        is_active=provider.is_active,
        config=provider.config or {},
        has_api_key=bool(provider.api_key),
    )


@admin_router.post("/search-providers/{provider_id}/deactivate")
def deactivate_search_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    deactivate_web_search_provider(provider_id=provider_id, db_session=db_session)
    db_session.commit()
    return {"status": "ok"}


@admin_router.post("/search-providers/test")
def test_search_provider(
    request: WebSearchProviderTestRequest,
    _: User | None = Depends(current_admin_user),
) -> dict[str, str]:
    try:
        provider = build_search_provider_from_config(
            provider_type=request.provider_type,
            api_key=request.api_key,
            config=request.config or {},
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if provider is None:
        raise HTTPException(
            status_code=400, detail="Unable to build provider configuration."
        )

    # Actually test the API key by making a real search call
    try:
        test_results = provider.search("test")
        if not test_results or not any(result.link for result in test_results):
            raise HTTPException(
                status_code=400,
                detail="API key validation failed: search returned no results.",
            )
    except HTTPException:
        raise
    except Exception as e:
        error_msg = str(e)
        if (
            "api" in error_msg.lower()
            or "key" in error_msg.lower()
            or "auth" in error_msg.lower()
        ):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid API key: {error_msg}",
            ) from e
        raise HTTPException(
            status_code=400,
            detail=f"API key validation failed: {error_msg}",
        ) from e

    logger.info(
        f"Web search provider test succeeded for {request.provider_type.value}."
    )
    return {"status": "ok"}


@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
def list_content_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[WebContentProviderView]:
    providers = fetch_web_content_providers(db_session)
    return [
        WebContentProviderView(
            id=provider.id,
            name=provider.name,
            provider_type=WebContentProviderType(provider.provider_type),
            is_active=provider.is_active,
            config=provider.config or {},
            has_api_key=bool(provider.api_key),
        )
        for provider in providers
    ]


@admin_router.post("/content-providers", response_model=WebContentProviderView)
def upsert_content_provider_endpoint(
    request: WebContentProviderUpsertRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> WebContentProviderView:
    existing_by_name = fetch_web_content_provider_by_name(request.name, db_session)
    if (
        existing_by_name
        and request.id is not None
        and existing_by_name.id != request.id
    ):
        raise HTTPException(
            status_code=400,
            detail=f"A content provider named '{request.name}' already exists.",
        )

    provider = upsert_web_content_provider(
        provider_id=request.id,
        name=request.name,
        provider_type=request.provider_type,
        api_key=request.api_key,
        api_key_changed=request.api_key_changed,
        config=request.config,
        activate=request.activate,
        db_session=db_session,
    )

    db_session.commit()
    return WebContentProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=WebContentProviderType(provider.provider_type),
        is_active=provider.is_active,
        config=provider.config or {},
        has_api_key=bool(provider.api_key),
    )


@admin_router.delete(
    "/content-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_content_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    delete_web_content_provider(provider_id, db_session)
    return Response(status_code=204)


@admin_router.post("/content-providers/{provider_id}/activate")
def activate_content_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> WebContentProviderView:
    provider = set_active_web_content_provider(
        provider_id=provider_id, db_session=db_session
    )
    db_session.commit()
    return WebContentProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=WebContentProviderType(provider.provider_type),
        is_active=provider.is_active,
        config=provider.config or {},
        has_api_key=bool(provider.api_key),
    )


@admin_router.post("/content-providers/reset-default")
def reset_content_provider_default(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    providers = fetch_web_content_providers(db_session)
    active_ids = [provider.id for provider in providers if provider.is_active]

    for provider_id in active_ids:
        deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
    db_session.commit()

    return {"status": "ok"}


@admin_router.post("/content-providers/{provider_id}/deactivate")
def deactivate_content_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
    db_session.commit()
    return {"status": "ok"}


@admin_router.post("/content-providers/test")
def test_content_provider(
    request: WebContentProviderTestRequest,
    _: User | None = Depends(current_admin_user),
) -> dict[str, str]:
    try:
        provider = build_content_provider_from_config(
            provider_type=request.provider_type,
            api_key=request.api_key,
            config=request.config or {},
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    if provider is None:
        raise HTTPException(
            status_code=400, detail="Unable to build provider configuration."
        )

    # Actually test the API key by making a real content fetch call
    try:
        test_url = "https://example.com"
        test_results = provider.contents([test_url])
        if not test_results or not any(
            result.scrape_successful for result in test_results
        ):
            raise HTTPException(
                status_code=400,
                detail="API key validation failed: content fetch returned no results.",
            )
    except HTTPException:
        raise
    except Exception as e:
        error_msg = str(e)
        if (
            "api" in error_msg.lower()
            or "key" in error_msg.lower()
            or "auth" in error_msg.lower()
        ):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid API key: {error_msg}",
            ) from e
        raise HTTPException(
            status_code=400,
            detail=f"API key validation failed: {error_msg}",
        ) from e

    logger.info(
        f"Web content provider test succeeded for {request.provider_type.value}."
    )
    return {"status": "ok"}
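For reference, a hedged sketch of exercising the test endpoint above over HTTP (host, port, and values are hypothetical; note this router is being deleted in this diff, so the sketch only applies to builds that still ship it):

import requests

resp = requests.post(
    "http://localhost:8080/admin/web-search/search-providers/test",  # hypothetical host/port
    json={"provider_type": "exa", "api_key": "sk-test-123"},  # hypothetical values
)
# 200 with {"status": "ok"} on success; 400 with a "detail" message otherwise
print(resp.status_code, resp.json())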
@@ -1,69 +0,0 @@
from typing import Any

from pydantic import BaseModel
from pydantic import Field

from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType


class WebSearchProviderView(BaseModel):
    id: int
    name: str
    provider_type: WebSearchProviderType
    is_active: bool
    config: dict[str, str] | None
    has_api_key: bool = Field(
        default=False,
        description="Indicates whether an API key is stored for this provider.",
    )


class WebSearchProviderUpsertRequest(BaseModel):
    id: int | None = Field(default=None, description="Existing provider ID to update.")
    name: str
    provider_type: WebSearchProviderType
    config: dict[str, str] | None = None
    api_key: str | None = Field(
        default=None,
        description="API key for the provider. Only required when creating or updating credentials.",
    )
    api_key_changed: bool = Field(
        default=False,
        description="Set to true when providing a new API key for an existing provider.",
    )
    activate: bool = Field(
        default=False,
        description="If true, sets this provider as the active one after upsert.",
    )


class WebContentProviderView(BaseModel):
    id: int
    name: str
    provider_type: WebContentProviderType
    is_active: bool
    config: dict[str, str] | None
    has_api_key: bool = Field(default=False)


class WebContentProviderUpsertRequest(BaseModel):
    id: int | None = None
    name: str
    provider_type: WebContentProviderType
    config: dict[str, str] | None = None
    api_key: str | None = None
    api_key_changed: bool = False
    activate: bool = False


class WebSearchProviderTestRequest(BaseModel):
    provider_type: WebSearchProviderType
    api_key: str | None = None
    config: dict[str, Any] | None = None


class WebContentProviderTestRequest(BaseModel):
    provider_type: WebContentProviderType
    api_key: str | None = None
    config: dict[str, Any] | None = None
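An illustrative upsert payload matching WebSearchProviderUpsertRequest above (all values are hypothetical; provider_type must be a valid WebSearchProviderType member):

payload = {
    "name": "Primary search provider",  # display name, must be unique
    "provider_type": "exa",  # hypothetical enum value
    "api_key": "sk-test-123",  # only needed when creating or rotating credentials
    "api_key_changed": True,
    "activate": True,  # make this the active provider after the upsert
}
# Omit "id" on create; include it to update an existing provider.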
@@ -16,7 +16,7 @@ class ForceUseTool(BaseModel):

    def build_openai_tool_choice_dict(self) -> dict[str, Any]:
        """Build dict in the format that OpenAI expects which tells them to use this tool."""
        return {"type": "function", "name": self.tool_name}
        return {"type": "function", "function": {"name": self.tool_name}}


def filter_tools_for_force_tool_use(

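The two return statements above are the before/after of this hunk. For context, the nested shape is the tool_choice format the OpenAI Chat Completions API expects, while the flat shape matches the newer Responses API ("run_search" is an illustrative tool name):

tool_choice_chat_completions = {"type": "function", "function": {"name": "run_search"}}
tool_choice_responses_api = {"type": "function", "name": "run_search"}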
@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from typing_extensions import override

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.message import ToolCallSummary
@@ -51,10 +51,8 @@ class WebSearchTool(Tool[None]):
    @override
    @classmethod
    def is_available(cls, db_session: Session) -> bool:
        """Available only if an active web search provider is configured in the database."""
        with get_session_with_current_tenant() as session:
            provider = fetch_active_web_search_provider(session)
            return provider is not None
        """Available only if EXA or SERPER API key is configured."""
        return bool(EXA_API_KEY) or bool(SERPER_API_KEY)

    def tool_definition(self) -> dict:
        return {
@@ -6,20 +6,14 @@ from pydantic import TypeAdapter

from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
    WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
    WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
    WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    get_default_content_provider,
    get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    get_default_provider,
    WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
    dummy_inference_section_from_internet_content,
@@ -196,7 +190,7 @@ changing or evolving.
def _open_url_core(
    run_context: RunContextWrapper[ChatTurnContext],
    urls: Sequence[str],
    content_provider: WebContentProvider,
    search_provider: WebSearchProvider,
) -> list[LlmOpenUrlResult]:
    # TODO: Find better way to track index that isn't so implicit
    # based on number of tool calls
@@ -212,7 +206,7 @@ def _open_url_core(
        )
    )

    docs = content_provider.contents(urls)
    docs = search_provider.contents(urls)
    results = [
        LlmOpenUrlResult(
            document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
@@ -273,10 +267,10 @@ def open_url(
    """
    Tool for fetching and extracting full content from web pages.
    """
    content_provider = get_default_content_provider()
    if content_provider is None:
        raise ValueError("No web content provider found")
    retrieved_docs = _open_url_core(run_context, urls, content_provider)
    search_provider = get_default_provider()
    if search_provider is None:
        raise ValueError("No search provider found")
    retrieved_docs = _open_url_core(run_context, urls, search_provider)
    adapter = TypeAdapter(list[LlmOpenUrlResult])
    return adapter.dump_json(retrieved_docs).decode()


@@ -80,11 +80,7 @@ def setup_braintrust_if_creds_available() -> None:
        api_key=BRAINTRUST_API_KEY,
    )
    braintrust.set_masking_function(_mask)
    set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
    _setup_legacy_langchain_tracing()
    logger.notice("Braintrust tracing initialized")


def _setup_legacy_langchain_tracing() -> None:
    handler = BraintrustCallbackHandler()
    set_global_handler(handler)
    set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
    logger.notice("Braintrust tracing initialized")

@@ -1,216 +0,0 @@
import datetime
from typing import Any
from typing import Dict
from typing import Optional

import braintrust
from braintrust import NOOP_SPAN

from .framework.processor_interface import TracingProcessor
from .framework.span_data import AgentSpanData
from .framework.span_data import FunctionSpanData
from .framework.span_data import GenerationSpanData
from .framework.span_data import SpanData
from .framework.spans import Span
from .framework.traces import Trace


def _span_type(span: Span[Any]) -> braintrust.SpanTypeAttribute:
    if span.span_data.type in ["agent"]:
        return braintrust.SpanTypeAttribute.TASK
    elif span.span_data.type in ["function"]:
        return braintrust.SpanTypeAttribute.TOOL
    elif span.span_data.type in ["generation"]:
        return braintrust.SpanTypeAttribute.LLM
    else:
        return braintrust.SpanTypeAttribute.TASK


def _span_name(span: Span[Any]) -> str:
    if isinstance(span.span_data, AgentSpanData) or isinstance(
        span.span_data, FunctionSpanData
    ):
        return span.span_data.name
    elif isinstance(span.span_data, GenerationSpanData):
        return "Generation"
    else:
        return "Unknown"


def _timestamp_from_maybe_iso(timestamp: Optional[str]) -> Optional[float]:
    if timestamp is None:
        return None
    return datetime.datetime.fromisoformat(timestamp).timestamp()


def _maybe_timestamp_elapsed(
    end: Optional[str], start: Optional[str]
) -> Optional[float]:
    if start is None or end is None:
        return None
    return (
        datetime.datetime.fromisoformat(end) - datetime.datetime.fromisoformat(start)
    ).total_seconds()


class BraintrustTracingProcessor(TracingProcessor):
    """
    `BraintrustTracingProcessor` is a `tracing.TracingProcessor` that logs traces to Braintrust.

    Args:
        logger: A `braintrust.Span` or `braintrust.Experiment` or `braintrust.Logger` to use for logging.
            If `None`, the current span, experiment, or logger will be selected exactly as in `braintrust.start_span`.
    """

    def __init__(self, logger: Optional[braintrust.Logger] = None):
        self._logger = logger
        self._spans: Dict[str, Any] = {}
        self._first_input: Dict[str, Any] = {}
        self._last_output: Dict[str, Any] = {}

    def on_trace_start(self, trace: Trace) -> None:
        trace_meta = trace.export() or {}
        metadata = {
            "group_id": trace_meta.get("group_id"),
            **(trace_meta.get("metadata") or {}),
        }

        current_context = braintrust.current_span()
        if current_context != NOOP_SPAN:
            self._spans[trace.trace_id] = current_context.start_span(  # type: ignore[assignment]
                name=trace.name,
                span_attributes={"type": "task", "name": trace.name},
                metadata=metadata,
            )
        elif self._logger is not None:
            self._spans[trace.trace_id] = self._logger.start_span(  # type: ignore[assignment]
                span_attributes={"type": "task", "name": trace.name},
                span_id=trace.trace_id,
                root_span_id=trace.trace_id,
                metadata=metadata,
            )
        else:
            self._spans[trace.trace_id] = braintrust.start_span(  # type: ignore[assignment]
                id=trace.trace_id,
                span_attributes={"type": "task", "name": trace.name},
                metadata=metadata,
            )

    def on_trace_end(self, trace: Trace) -> None:
        span: Any = self._spans.pop(trace.trace_id)
        # Get the first input and last output for this specific trace
        trace_first_input = self._first_input.pop(trace.trace_id, None)
        trace_last_output = self._last_output.pop(trace.trace_id, None)
        span.log(input=trace_first_input, output=trace_last_output)
        span.end()

    def _agent_log_data(self, span: Span[AgentSpanData]) -> Dict[str, Any]:
        return {
            "metadata": {
                "tools": span.span_data.tools,
                "handoffs": span.span_data.handoffs,
                "output_type": span.span_data.output_type,
            }
        }

    def _function_log_data(self, span: Span[FunctionSpanData]) -> Dict[str, Any]:
        return {
            "input": span.span_data.input,
            "output": span.span_data.output,
        }

    def _generation_log_data(self, span: Span[GenerationSpanData]) -> Dict[str, Any]:
        metrics = {}
        ttft = _maybe_timestamp_elapsed(span.ended_at, span.started_at)

        if ttft is not None:
            metrics["time_to_first_token"] = ttft

        usage = span.span_data.usage or {}
        if "prompt_tokens" in usage:
            metrics["prompt_tokens"] = usage["prompt_tokens"]
        elif "input_tokens" in usage:
            metrics["prompt_tokens"] = usage["input_tokens"]

        if "completion_tokens" in usage:
            metrics["completion_tokens"] = usage["completion_tokens"]
        elif "output_tokens" in usage:
            metrics["completion_tokens"] = usage["output_tokens"]

        if "total_tokens" in usage:
            metrics["tokens"] = usage["total_tokens"]
        elif "input_tokens" in usage and "output_tokens" in usage:
            metrics["tokens"] = usage["input_tokens"] + usage["output_tokens"]

        if "cache_read_input_tokens" in usage:
            metrics["prompt_cached_tokens"] = usage["cache_read_input_tokens"]
        if "cache_creation_input_tokens" in usage:
            metrics["prompt_cache_creation_tokens"] = usage[
                "cache_creation_input_tokens"
            ]

        return {
            "input": span.span_data.input,
            "output": span.span_data.output,
            "metadata": {
                "model": span.span_data.model,
                "model_config": span.span_data.model_config,
            },
            "metrics": metrics,
        }

    def _log_data(self, span: Span[Any]) -> Dict[str, Any]:
        if isinstance(span.span_data, AgentSpanData):
            return self._agent_log_data(span)
        elif isinstance(span.span_data, FunctionSpanData):
            return self._function_log_data(span)
        elif isinstance(span.span_data, GenerationSpanData):
            return self._generation_log_data(span)
        else:
            return {}

    def on_span_start(self, span: Span[SpanData]) -> None:
        parent: Any = (
            self._spans[span.parent_id]
            if span.parent_id is not None
            else self._spans[span.trace_id]
        )
        created_span: Any = parent.start_span(
            id=span.span_id,
            name=_span_name(span),
            type=_span_type(span),
            start_time=_timestamp_from_maybe_iso(span.started_at),
        )
        self._spans[span.span_id] = created_span

        # Set the span as current so current_span() calls will return it
        created_span.set_current()

    def on_span_end(self, span: Span[SpanData]) -> None:
        s: Any = self._spans.pop(span.span_id)
        event = dict(error=span.error, **self._log_data(span))
        s.log(**event)
        s.unset_current()
        s.end(_timestamp_from_maybe_iso(span.ended_at))

        input_ = event.get("input")
        output = event.get("output")
        # Store first input and last output per trace_id
        trace_id = span.trace_id
        if trace_id not in self._first_input and input_ is not None:
            self._first_input[trace_id] = input_

        if output is not None:
            self._last_output[trace_id] = output

    def shutdown(self) -> None:
        if self._logger is not None:
            self._logger.flush()
        else:
            braintrust.flush()

    def force_flush(self) -> None:
        if self._logger is not None:
            self._logger.flush()
        else:
            braintrust.flush()
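The setup_braintrust_if_creds_available hunk earlier in this diff shows the real wiring; as a standalone sketch (the project name is illustrative, and braintrust.init_logger is assumed from the Braintrust SDK):

import braintrust

braintrust_logger = braintrust.init_logger(project="onyx")  # illustrative project name
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])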
@@ -1,21 +0,0 @@
from .processor_interface import TracingProcessor
from .provider import DefaultTraceProvider
from .setup import get_trace_provider
from .setup import set_trace_provider


def add_trace_processor(span_processor: TracingProcessor) -> None:
    """
    Adds a new trace processor. This processor will receive all traces/spans.
    """
    get_trace_provider().register_processor(span_processor)


def set_trace_processors(processors: list[TracingProcessor]) -> None:
    """
    Set the list of trace processors. This will replace the current list of processors.
    """
    get_trace_provider().set_processors(processors)


set_trace_provider(DefaultTraceProvider())
@@ -1,21 +0,0 @@
from typing import Any

from .create import get_current_span
from .spans import Span
from .spans import SpanError
from onyx.utils.logger import setup_logger


logger = setup_logger(__name__)


def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
    span.set_error(error)


def attach_error_to_current_span(error: SpanError) -> None:
    span = get_current_span()
    if span:
        attach_error_to_span(span, error)
    else:
        logger.warning(f"No span to add error {error} to")
@@ -1,192 +0,0 @@
from __future__ import annotations

from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING

from .setup import get_trace_provider
from .span_data import AgentSpanData
from .span_data import FunctionSpanData
from .span_data import GenerationSpanData
from .spans import Span
from .traces import Trace
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    pass

logger = setup_logger(__name__)


def trace(
    workflow_name: str,
    trace_id: str | None = None,
    group_id: str | None = None,
    metadata: dict[str, Any] | None = None,
    disabled: bool = False,
) -> Trace:
    """
    Create a new trace. The trace will not be started automatically; you should either use
    it as a context manager (`with trace(...):`) or call `trace.start()` + `trace.finish()`
    manually.

    In addition to the workflow name and optional grouping identifier, you can provide
    an arbitrary metadata dictionary to attach additional user-defined information to
    the trace.

    Args:
        workflow_name: The name of the logical app or workflow. For example, you might provide
            "code_bot" for a coding agent, or "customer_support_agent" for a customer support agent.
        trace_id: The ID of the trace. Optional. If not provided, we will generate an ID. We
            recommend using `util.gen_trace_id()` to generate a trace ID, to guarantee that IDs are
            correctly formatted.
        group_id: Optional grouping identifier to link multiple traces from the same conversation
            or process. For instance, you might use a chat thread ID.
        metadata: Optional dictionary of additional metadata to attach to the trace.
        disabled: If True, we will return a Trace but the Trace will not be recorded.

    Returns:
        The newly created trace object.
    """
    current_trace = get_trace_provider().get_current_trace()
    if current_trace:
        logger.warning(
            "Trace already exists. Creating a new trace, but this is probably a mistake."
        )

    return get_trace_provider().create_trace(
        name=workflow_name,
        trace_id=trace_id,
        group_id=group_id,
        metadata=metadata,
        disabled=disabled,
    )

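A usage sketch following the docstring above (names are illustrative):

chat_thread_id = "thread-123"  # illustrative grouping id
with trace("customer_support_agent", group_id=chat_thread_id):
    ...  # spans created here attach to this trace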

def get_current_trace() -> Trace | None:
    """Returns the currently active trace, if present."""
    return get_trace_provider().get_current_trace()


def get_current_span() -> Span[Any] | None:
    """Returns the currently active span, if present."""
    return get_trace_provider().get_current_span()


def agent_span(
    name: str,
    handoffs: list[str] | None = None,
    tools: list[str] | None = None,
    output_type: str | None = None,
    span_id: str | None = None,
    parent: Trace | Span[Any] | None = None,
    disabled: bool = False,
) -> Span[AgentSpanData]:
    """Create a new agent span. The span will not be started automatically, you should either do
    `with agent_span() ...` or call `span.start()` + `span.finish()` manually.

    Args:
        name: The name of the agent.
        handoffs: Optional list of agent names to which this agent could hand off control.
        tools: Optional list of tool names available to this agent.
        output_type: Optional name of the output type produced by the agent.
        span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
            recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
            correctly formatted.
        parent: The parent span or trace. If not provided, we will automatically use the current
            trace/span as the parent.
        disabled: If True, we will return a Span but the Span will not be recorded.

    Returns:
        The newly created agent span.
    """
    return get_trace_provider().create_span(
        span_data=AgentSpanData(
            name=name, handoffs=handoffs, tools=tools, output_type=output_type
        ),
        span_id=span_id,
        parent=parent,
        disabled=disabled,
    )


def function_span(
    name: str,
    input: str | None = None,
    output: str | None = None,
    span_id: str | None = None,
    parent: Trace | Span[Any] | None = None,
    disabled: bool = False,
) -> Span[FunctionSpanData]:
    """Create a new function span. The span will not be started automatically, you should either do
    `with function_span() ...` or call `span.start()` + `span.finish()` manually.

    Args:
        name: The name of the function.
        input: The input to the function.
        output: The output of the function.
        span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
            recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
            correctly formatted.
        parent: The parent span or trace. If not provided, we will automatically use the current
            trace/span as the parent.
        disabled: If True, we will return a Span but the Span will not be recorded.

    Returns:
        The newly created function span.
    """
    return get_trace_provider().create_span(
        span_data=FunctionSpanData(name=name, input=input, output=output),
        span_id=span_id,
        parent=parent,
        disabled=disabled,
    )


def generation_span(
    input: Sequence[Mapping[str, Any]] | None = None,
    output: Sequence[Mapping[str, Any]] | None = None,
    model: str | None = None,
    model_config: Mapping[str, Any] | None = None,
    usage: dict[str, Any] | None = None,
    span_id: str | None = None,
    parent: Trace | Span[Any] | None = None,
    disabled: bool = False,
) -> Span[GenerationSpanData]:
    """Create a new generation span. The span will not be started automatically, you should either
    do `with generation_span() ...` or call `span.start()` + `span.finish()` manually.

    This span captures the details of a model generation, including the
    input message sequence, any generated outputs, the model name and
    configuration, and usage data. If you only need to capture a model
    response identifier, use `response_span()` instead.

    Args:
        input: The sequence of input messages sent to the model.
        output: The sequence of output messages received from the model.
        model: The model identifier used for the generation.
        model_config: The model configuration (hyperparameters) used.
        usage: A dictionary of usage information (input tokens, output tokens, etc.).
        span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
            recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
            correctly formatted.
        parent: The parent span or trace. If not provided, we will automatically use the current
            trace/span as the parent.
        disabled: If True, we will return a Span but the Span will not be recorded.

    Returns:
        The newly created generation span.
    """
    return get_trace_provider().create_span(
        span_data=GenerationSpanData(
            input=input,
            output=output,
            model=model,
            model_config=model_config,
            usage=usage,
        ),
        span_id=span_id,
        parent=parent,
        disabled=disabled,
    )
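A sketch combining generation_span with trace() from above (model id and usage counts are illustrative):

with trace("code_bot"):
    with generation_span(
        input=[{"role": "user", "content": "Hi"}],
        model="gpt-4o",  # illustrative model id
        usage={"input_tokens": 12, "output_tokens": 34},  # illustrative counts
    ):
        ...  # invoke the model here; processors receive on_span_start/on_span_end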
@@ -1,136 +0,0 @@
import abc
from typing import Any
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from .spans import Span
    from .traces import Trace


class TracingProcessor(abc.ABC):
    """Interface for processing and monitoring traces and spans in the OpenAI Agents system.

    This abstract class defines the interface that all tracing processors must implement.
    Processors receive notifications when traces and spans start and end, allowing them
    to collect, process, and export tracing data.

    Example:
        ```python
        class CustomProcessor(TracingProcessor):
            def __init__(self):
                self.active_traces = {}
                self.active_spans = {}

            def on_trace_start(self, trace):
                self.active_traces[trace.trace_id] = trace

            def on_trace_end(self, trace):
                # Process completed trace
                del self.active_traces[trace.trace_id]

            def on_span_start(self, span):
                self.active_spans[span.span_id] = span

            def on_span_end(self, span):
                # Process completed span
                del self.active_spans[span.span_id]

            def shutdown(self):
                # Clean up resources
                self.active_traces.clear()
                self.active_spans.clear()

            def force_flush(self):
                # Force processing of any queued items
                pass
        ```

    Notes:
        - All methods should be thread-safe
        - Methods should not block for long periods
        - Handle errors gracefully to prevent disrupting agent execution
    """

    @abc.abstractmethod
    def on_trace_start(self, trace: "Trace") -> None:
        """Called when a new trace begins execution.

        Args:
            trace: The trace that started. Contains workflow name and metadata.

        Notes:
            - Called synchronously on trace start
            - Should return quickly to avoid blocking execution
            - Any errors should be caught and handled internally
        """

    @abc.abstractmethod
    def on_trace_end(self, trace: "Trace") -> None:
        """Called when a trace completes execution.

        Args:
            trace: The completed trace containing all spans and results.

        Notes:
            - Called synchronously when trace finishes
            - Good time to export/process the complete trace
            - Should handle cleanup of any trace-specific resources
        """

    @abc.abstractmethod
    def on_span_start(self, span: "Span[Any]") -> None:
        """Called when a new span begins execution.

        Args:
            span: The span that started. Contains operation details and context.

        Notes:
            - Called synchronously on span start
            - Should return quickly to avoid blocking execution
            - Spans are automatically nested under current trace/span
        """

    @abc.abstractmethod
    def on_span_end(self, span: "Span[Any]") -> None:
        """Called when a span completes execution.

        Args:
            span: The completed span containing execution results.

        Notes:
            - Called synchronously when span finishes
            - Should not block or raise exceptions
            - Good time to export/process the individual span
        """

    @abc.abstractmethod
    def shutdown(self) -> None:
        """Called when the application stops to clean up resources.

        Should perform any necessary cleanup like:
        - Flushing queued traces/spans
        - Closing connections
        - Releasing resources
        """

    @abc.abstractmethod
    def force_flush(self) -> None:
        """Forces immediate processing of any queued traces/spans.

        Notes:
            - Should process all queued items before returning
            - Useful before shutdown or when immediate processing is needed
            - May block while processing completes
        """


class TracingExporter(abc.ABC):
    """Exports traces and spans. For example, could log them or send them to a backend."""

    @abc.abstractmethod
    def export(self, items: list["Trace | Span[Any]"]) -> None:
        """Exports a list of traces and spans.

        Args:
            items: The items to export.
        """
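Since `TracingExporter` is the extension point here, a minimal sketch of a concrete exporter may help (illustrative only; `ConsoleExporter` is not part of this change):

```python
class ConsoleExporter(TracingExporter):
    """Illustrative exporter that prints each trace/span as a dict."""

    def export(self, items: list["Trace | Span[Any]"]) -> None:
        for item in items:
            exported = item.export()  # both Trace and Span expose export()
            if exported is not None:
                print(exported)
```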
@@ -1,319 +0,0 @@
from __future__ import annotations

import threading
import uuid
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timezone
from typing import Any

from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan
from .spans import Span
from .spans import SpanImpl
from .spans import TSpanData
from .traces import NoOpTrace
from .traces import Trace
from .traces import TraceImpl
from onyx.utils.logger import setup_logger

logger = setup_logger(__name__)


class SynchronousMultiTracingProcessor(TracingProcessor):
    """
    Forwards all calls to a list of TracingProcessors, in order of registration.
    """

    def __init__(self) -> None:
        # Using a tuple to avoid race conditions when iterating over processors
        self._processors: tuple[TracingProcessor, ...] = ()
        self._lock = threading.Lock()

    def add_tracing_processor(self, tracing_processor: TracingProcessor) -> None:
        """
        Add a processor to the list of processors. Each processor will receive all traces/spans.
        """
        with self._lock:
            self._processors += (tracing_processor,)

    def set_processors(self, processors: list[TracingProcessor]) -> None:
        """
        Set the list of processors. This will replace the current list of processors.
        """
        with self._lock:
            self._processors = tuple(processors)

    def on_trace_start(self, trace: Trace) -> None:
        """
        Called when a trace is started.
        """
        for processor in self._processors:
            try:
                processor.on_trace_start(trace)
            except Exception as e:
                logger.error(
                    f"Error in trace processor {processor} during on_trace_start: {e}"
                )

    def on_trace_end(self, trace: Trace) -> None:
        """
        Called when a trace is finished.
        """
        for processor in self._processors:
            try:
                processor.on_trace_end(trace)
            except Exception as e:
                logger.error(
                    f"Error in trace processor {processor} during on_trace_end: {e}"
                )

    def on_span_start(self, span: Span[Any]) -> None:
        """
        Called when a span is started.
        """
        for processor in self._processors:
            try:
                processor.on_span_start(span)
            except Exception as e:
                logger.error(
                    f"Error in trace processor {processor} during on_span_start: {e}"
                )

    def on_span_end(self, span: Span[Any]) -> None:
        """
        Called when a span is finished.
        """
        for processor in self._processors:
            try:
                processor.on_span_end(span)
            except Exception as e:
                logger.error(
                    f"Error in trace processor {processor} during on_span_end: {e}"
                )

    def shutdown(self) -> None:
        """
        Called when the application stops.
        """
        for processor in self._processors:
            logger.debug(f"Shutting down trace processor {processor}")
            try:
                processor.shutdown()
            except Exception as e:
                logger.error(f"Error shutting down trace processor {processor}: {e}")

    def force_flush(self) -> None:
        """
        Force the processors to flush their buffers.
        """
        for processor in self._processors:
            try:
                processor.force_flush()
            except Exception as e:
                logger.error(f"Error flushing trace processor {processor}: {e}")


class TraceProvider(ABC):
    """Interface for creating traces and spans."""

    @abstractmethod
    def register_processor(self, processor: TracingProcessor) -> None:
        """Add a processor that will receive all traces and spans."""

    @abstractmethod
    def set_processors(self, processors: list[TracingProcessor]) -> None:
        """Replace the list of processors with ``processors``."""

    @abstractmethod
    def get_current_trace(self) -> Trace | None:
        """Return the currently active trace, if any."""

    @abstractmethod
    def get_current_span(self) -> Span[Any] | None:
        """Return the currently active span, if any."""

    @abstractmethod
    def time_iso(self) -> str:
        """Return the current time in ISO 8601 format."""

    @abstractmethod
    def gen_trace_id(self) -> str:
        """Generate a new trace identifier."""

    @abstractmethod
    def gen_span_id(self) -> str:
        """Generate a new span identifier."""

    @abstractmethod
    def gen_group_id(self) -> str:
        """Generate a new group identifier."""

    @abstractmethod
    def create_trace(
        self,
        name: str,
        trace_id: str | None = None,
        group_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        disabled: bool = False,
    ) -> Trace:
        """Create a new trace."""

    @abstractmethod
    def create_span(
        self,
        span_data: TSpanData,
        span_id: str | None = None,
        parent: Trace | Span[Any] | None = None,
        disabled: bool = False,
    ) -> Span[TSpanData]:
        """Create a new span."""

    @abstractmethod
    def shutdown(self) -> None:
        """Clean up any resources used by the provider."""


class DefaultTraceProvider(TraceProvider):
    def __init__(self) -> None:
        self._multi_processor = SynchronousMultiTracingProcessor()

    def register_processor(self, processor: TracingProcessor) -> None:
        """
        Add a processor to the list of processors. Each processor will receive all traces/spans.
        """
        self._multi_processor.add_tracing_processor(processor)

    def set_processors(self, processors: list[TracingProcessor]) -> None:
        """
        Set the list of processors. This will replace the current list of processors.
        """
        self._multi_processor.set_processors(processors)

    def get_current_trace(self) -> Trace | None:
        """
        Returns the currently active trace, if any.
        """
        return Scope.get_current_trace()

    def get_current_span(self) -> Span[Any] | None:
        """
        Returns the currently active span, if any.
        """
        return Scope.get_current_span()

    def time_iso(self) -> str:
        """Return the current time in ISO 8601 format."""
        return datetime.now(timezone.utc).isoformat()

    def gen_trace_id(self) -> str:
        """Generate a new trace ID."""
        return f"trace_{uuid.uuid4().hex}"

    def gen_span_id(self) -> str:
        """Generate a new span ID."""
        return f"span_{uuid.uuid4().hex[:24]}"

    def gen_group_id(self) -> str:
        """Generate a new group ID."""
        return f"group_{uuid.uuid4().hex[:24]}"

    def create_trace(
        self,
        name: str,
        trace_id: str | None = None,
        group_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        disabled: bool = False,
    ) -> Trace:
        """
        Create a new trace.
        """
        if disabled:
            logger.debug(f"Tracing is disabled. Not creating trace {name}")
            return NoOpTrace()

        trace_id = trace_id or self.gen_trace_id()

        logger.debug(f"Creating trace {name} with id {trace_id}")

        return TraceImpl(
            name=name,
            trace_id=trace_id,
            group_id=group_id,
            metadata=metadata,
            processor=self._multi_processor,
        )

    def create_span(
        self,
        span_data: TSpanData,
        span_id: str | None = None,
        parent: Trace | Span[Any] | None = None,
        disabled: bool = False,
    ) -> Span[TSpanData]:
        """
        Create a new span.
        """
        if disabled:
            logger.debug(f"Tracing is disabled. Not creating span {span_data}")
            return NoOpSpan(span_data)

        trace_id: str
        parent_id: str | None

        if not parent:
            current_span = Scope.get_current_span()
            current_trace = Scope.get_current_trace()
            if current_trace is None:
                logger.error(
                    "No active trace. Make sure to start a trace with `trace()` first. "
                    "Returning NoOpSpan."
                )
                return NoOpSpan(span_data)
            elif isinstance(current_trace, NoOpTrace) or isinstance(
                current_span, NoOpSpan
            ):
                logger.debug(
                    f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
                )
                return NoOpSpan(span_data)

            parent_id = current_span.span_id if current_span else None
            trace_id = current_trace.trace_id

        elif isinstance(parent, Trace):
            if isinstance(parent, NoOpTrace):
                logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
                return NoOpSpan(span_data)
            trace_id = parent.trace_id
            parent_id = None
        elif isinstance(parent, Span):
            if isinstance(parent, NoOpSpan):
                logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
                return NoOpSpan(span_data)
            parent_id = parent.span_id
            trace_id = parent.trace_id
        else:
            # This should never happen, but mypy needs it
            raise ValueError(f"Invalid parent type: {type(parent)}")

        logger.debug(f"Creating span {span_data} with id {span_id}")

        return SpanImpl(
            trace_id=trace_id,
            span_id=span_id or self.gen_span_id(),
            parent_id=parent_id,
            processor=self._multi_processor,
            span_data=span_data,
        )

    def shutdown(self) -> None:
        try:
            logger.debug("Shutting down trace provider")
            self._multi_processor.shutdown()
        except Exception as e:
            logger.error(f"Error shutting down trace provider: {e}")
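Putting the provider and multiplexer together, registration might look like the sketch below (the two processor classes are hypothetical stand-ins for any `TracingProcessor` implementations):

```python
provider = DefaultTraceProvider()
provider.register_processor(MetricsProcessor())  # hypothetical TracingProcessor
provider.register_processor(LoggingProcessor())  # hypothetical TracingProcessor

t = provider.create_trace("order-processing")
t.start(mark_as_current=True)
# every create_span(...) under this trace now fans out to both processors
t.finish(reset_current=True)
```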
@@ -1,49 +0,0 @@
import contextvars
from typing import Any
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .spans import Span
    from .traces import Trace

_current_span: contextvars.ContextVar["Span[Any] | None"] = contextvars.ContextVar(
    "current_span", default=None
)

_current_trace: contextvars.ContextVar["Trace | None"] = contextvars.ContextVar(
    "current_trace", default=None
)


class Scope:
    """
    Manages the current span and trace in the context.
    """

    @classmethod
    def get_current_span(cls) -> "Span[Any] | None":
        return _current_span.get()

    @classmethod
    def set_current_span(
        cls, span: "Span[Any] | None"
    ) -> "contextvars.Token[Span[Any] | None]":
        return _current_span.set(span)

    @classmethod
    def reset_current_span(cls, token: "contextvars.Token[Span[Any] | None]") -> None:
        _current_span.reset(token)

    @classmethod
    def get_current_trace(cls) -> "Trace | None":
        return _current_trace.get()

    @classmethod
    def set_current_trace(
        cls, trace: "Trace | None"
    ) -> "contextvars.Token[Trace | None]":
        return _current_trace.set(trace)

    @classmethod
    def reset_current_trace(cls, token: "contextvars.Token[Trace | None]") -> None:
        _current_trace.reset(token)
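Because `Scope` wraps `contextvars`, each `set_*` call returns a token that must be passed back to the matching `reset_*`; a minimal sketch:

```python
token = Scope.set_current_trace(my_trace)  # my_trace: any Trace instance
try:
    pass  # spans created here resolve my_trace as the current trace
finally:
    Scope.reset_current_trace(token)  # restores whatever was current before
```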
@@ -1,21 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .provider import TraceProvider

GLOBAL_TRACE_PROVIDER: TraceProvider | None = None


def set_trace_provider(provider: TraceProvider) -> None:
    """Set the global trace provider used by tracing utilities."""
    global GLOBAL_TRACE_PROVIDER
    GLOBAL_TRACE_PROVIDER = provider


def get_trace_provider() -> TraceProvider:
    """Get the global trace provider used by tracing utilities."""
    if GLOBAL_TRACE_PROVIDER is None:
        raise RuntimeError("Trace provider not set")
    return GLOBAL_TRACE_PROVIDER
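Typical bootstrapping of the global provider might look like this (module paths are assumptions; only the `onyx.tracing.framework.*` package prefix is confirmed later in this diff):

```python
from onyx.tracing.framework.provider import DefaultTraceProvider  # assumed path
from onyx.tracing.framework.setup import (  # hypothetical module name
    get_trace_provider,
    set_trace_provider,
)

set_trace_provider(DefaultTraceProvider())
t = get_trace_provider().create_trace("my-workflow")
```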
@@ -1,130 +0,0 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any


class SpanData(abc.ABC):
    """
    Represents span data in the trace.
    """

    @abc.abstractmethod
    def export(self) -> dict[str, Any]:
        """Export the span data as a dictionary."""

    @property
    @abc.abstractmethod
    def type(self) -> str:
        """Return the type of the span."""


class AgentSpanData(SpanData):
    """
    Represents an Agent Span in the trace.
    Includes name, handoffs, tools, and output type.
    """

    __slots__ = ("name", "handoffs", "tools", "output_type")

    def __init__(
        self,
        name: str,
        handoffs: list[str] | None = None,
        tools: list[str] | None = None,
        output_type: str | None = None,
    ):
        self.name = name
        self.handoffs: list[str] | None = handoffs
        self.tools: list[str] | None = tools
        self.output_type: str | None = output_type

    @property
    def type(self) -> str:
        return "agent"

    def export(self) -> dict[str, Any]:
        return {
            "type": self.type,
            "name": self.name,
            "handoffs": self.handoffs,
            "tools": self.tools,
            "output_type": self.output_type,
        }


class FunctionSpanData(SpanData):
    """
    Represents a Function Span in the trace.
    Includes input, output and MCP data (if applicable).
    """

    __slots__ = ("name", "input", "output", "mcp_data")

    def __init__(
        self,
        name: str,
        input: str | None,
        output: Any | None,
        mcp_data: dict[str, Any] | None = None,
    ):
        self.name = name
        self.input = input
        self.output = output
        self.mcp_data = mcp_data

    @property
    def type(self) -> str:
        return "function"

    def export(self) -> dict[str, Any]:
        return {
            "type": self.type,
            "name": self.name,
            "input": self.input,
            "output": str(self.output) if self.output else None,
            "mcp_data": self.mcp_data,
        }


class GenerationSpanData(SpanData):
    """
    Represents a Generation Span in the trace.
    Includes input, output, model, model configuration, and usage.
    """

    __slots__ = (
        "input",
        "output",
        "model",
        "model_config",
        "usage",
    )

    def __init__(
        self,
        input: Sequence[Mapping[str, Any]] | None = None,
        output: Sequence[Mapping[str, Any]] | None = None,
        model: str | None = None,
        model_config: Mapping[str, Any] | None = None,
        usage: dict[str, Any] | None = None,
    ):
        self.input = input
        self.output = output
        self.model = model
        self.model_config = model_config
        self.usage = usage

    @property
    def type(self) -> str:
        return "generation"

    def export(self) -> dict[str, Any]:
        return {
            "type": self.type,
            "input": self.input,
            "output": self.output,
            "model": self.model,
            "model_config": self.model_config,
            "usage": self.usage,
        }
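For example, a generation payload round-trips to a flat dict via `export()` (values illustrative):

```python
data = GenerationSpanData(
    input=[{"role": "user", "content": "2+2?"}],
    output=[{"role": "assistant", "content": "4"}],
    model="gpt-4o-mini",
    usage={"input_tokens": 6, "output_tokens": 1},
)
assert data.type == "generation"
print(data.export())  # {"type": "generation", "input": [...], "model": "gpt-4o-mini", ...}
```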
@@ -1,356 +0,0 @@
from __future__ import annotations

import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar

from typing_extensions import TypedDict

from . import util
from .processor_interface import TracingProcessor
from .scope import Scope
from .span_data import SpanData

TSpanData = TypeVar("TSpanData", bound=SpanData)


class SpanError(TypedDict):
    """Represents an error that occurred during span execution.

    Attributes:
        message: A human-readable error description
        data: Optional dictionary containing additional error context
    """

    message: str
    data: dict[str, Any] | None


class Span(abc.ABC, Generic[TSpanData]):
    """Base class for representing traceable operations with timing and context.

    A span represents a single operation within a trace (e.g., an LLM call, tool execution,
    or agent run). Spans track timing, relationships between operations, and operation-specific
    data.

    Type Args:
        TSpanData: The type of span-specific data this span contains.

    Example:
        ```python
        # Creating a custom span
        with custom_span("database_query", {
            "operation": "SELECT",
            "table": "users"
        }) as span:
            results = await db.query("SELECT * FROM users")
            span.set_output({"count": len(results)})

        # Handling errors in spans
        with custom_span("risky_operation") as span:
            try:
                result = perform_risky_operation()
            except Exception as e:
                span.set_error({
                    "message": str(e),
                    "data": {"operation": "risky_operation"}
                })
                raise
        ```

    Notes:
        - Spans automatically nest under the current trace
        - Use context managers for reliable start/finish
        - Include relevant data but avoid sensitive information
        - Handle errors properly using set_error()
    """

    @property
    @abc.abstractmethod
    def trace_id(self) -> str:
        """The ID of the trace this span belongs to.

        Returns:
            str: Unique identifier of the parent trace.
        """

    @property
    @abc.abstractmethod
    def span_id(self) -> str:
        """Unique identifier for this span.

        Returns:
            str: The span's unique ID within its trace.
        """

    @property
    @abc.abstractmethod
    def span_data(self) -> TSpanData:
        """Operation-specific data for this span.

        Returns:
            TSpanData: Data specific to this type of span (e.g., LLM generation data).
        """

    @abc.abstractmethod
    def start(self, mark_as_current: bool = False) -> None:
        """
        Start the span.

        Args:
            mark_as_current: If true, the span will be marked as the current span.
        """

    @abc.abstractmethod
    def finish(self, reset_current: bool = False) -> None:
        """
        Finish the span.

        Args:
            reset_current: If true, the span will be reset as the current span.
        """

    @abc.abstractmethod
    def __enter__(self) -> Span[TSpanData]:
        pass

    @abc.abstractmethod
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass

    @property
    @abc.abstractmethod
    def parent_id(self) -> str | None:
        """ID of the parent span, if any.

        Returns:
            str | None: The parent span's ID, or None if this is a root span.
        """

    @abc.abstractmethod
    def set_error(self, error: SpanError) -> None:
        pass

    @property
    @abc.abstractmethod
    def error(self) -> SpanError | None:
        """Any error that occurred during span execution.

        Returns:
            SpanError | None: Error details if an error occurred, None otherwise.
        """

    @abc.abstractmethod
    def export(self) -> dict[str, Any] | None:
        pass

    @property
    @abc.abstractmethod
    def started_at(self) -> str | None:
        """When the span started execution.

        Returns:
            str | None: ISO format timestamp of span start, None if not started.
        """

    @property
    @abc.abstractmethod
    def ended_at(self) -> str | None:
        """When the span finished execution.

        Returns:
            str | None: ISO format timestamp of span end, None if not finished.
        """


class NoOpSpan(Span[TSpanData]):
    """A no-op implementation of Span that doesn't record any data.

    Used when tracing is disabled but span operations still need to work.

    Args:
        span_data: The operation-specific data for this span.
    """

    __slots__ = ("_span_data", "_prev_span_token")

    def __init__(self, span_data: TSpanData):
        self._span_data = span_data
        self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None

    @property
    def trace_id(self) -> str:
        return "no-op"

    @property
    def span_id(self) -> str:
        return "no-op"

    @property
    def span_data(self) -> TSpanData:
        return self._span_data

    @property
    def parent_id(self) -> str | None:
        return None

    def start(self, mark_as_current: bool = False) -> None:
        if mark_as_current:
            self._prev_span_token = Scope.set_current_span(self)

    def finish(self, reset_current: bool = False) -> None:
        if reset_current and self._prev_span_token is not None:
            Scope.reset_current_span(self._prev_span_token)
            self._prev_span_token = None

    def __enter__(self) -> Span[TSpanData]:
        self.start(mark_as_current=True)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        reset_current = True
        if exc_type is GeneratorExit:
            reset_current = False

        self.finish(reset_current=reset_current)

    def set_error(self, error: SpanError) -> None:
        pass

    @property
    def error(self) -> SpanError | None:
        return None

    def export(self) -> dict[str, Any] | None:
        return None

    @property
    def started_at(self) -> str | None:
        return None

    @property
    def ended_at(self) -> str | None:
        return None


class SpanImpl(Span[TSpanData]):
    __slots__ = (
        "_trace_id",
        "_span_id",
        "_parent_id",
        "_started_at",
        "_ended_at",
        "_error",
        "_prev_span_token",
        "_processor",
        "_span_data",
    )

    def __init__(
        self,
        trace_id: str,
        span_id: str | None,
        parent_id: str | None,
        processor: TracingProcessor,
        span_data: TSpanData,
    ):
        self._trace_id = trace_id
        self._span_id = span_id or util.gen_span_id()
        self._parent_id = parent_id
        self._started_at: str | None = None
        self._ended_at: str | None = None
        self._processor = processor
        self._error: SpanError | None = None
        self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
        self._span_data = span_data

    @property
    def trace_id(self) -> str:
        return self._trace_id

    @property
    def span_id(self) -> str:
        return self._span_id

    @property
    def span_data(self) -> TSpanData:
        return self._span_data

    @property
    def parent_id(self) -> str | None:
        return self._parent_id

    def start(self, mark_as_current: bool = False) -> None:
        if self.started_at is not None:
            return

        self._started_at = util.time_iso()
        self._processor.on_span_start(self)
        if mark_as_current:
            self._prev_span_token = Scope.set_current_span(self)

    def finish(self, reset_current: bool = False) -> None:
        if self.ended_at is not None:
            return

        self._ended_at = util.time_iso()
        self._processor.on_span_end(self)
        if reset_current and self._prev_span_token is not None:
            Scope.reset_current_span(self._prev_span_token)
            self._prev_span_token = None

    def __enter__(self) -> Span[TSpanData]:
        self.start(mark_as_current=True)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        reset_current = True
        if exc_type is GeneratorExit:
            reset_current = False

        self.finish(reset_current=reset_current)

    def set_error(self, error: SpanError) -> None:
        self._error = error

    @property
    def error(self) -> SpanError | None:
        return self._error

    @property
    def started_at(self) -> str | None:
        return self._started_at

    @property
    def ended_at(self) -> str | None:
        return self._ended_at

    def export(self) -> dict[str, Any] | None:
        return {
            "object": "trace.span",
            "id": self.span_id,
            "trace_id": self.trace_id,
            "parent_id": self._parent_id,
            "started_at": self._started_at,
            "ended_at": self._ended_at,
            "span_data": self.span_data.export(),
            "error": self._error,
        }
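Outside a `with` block the lifecycle can be driven manually; `start()`/`finish()` are idempotent thanks to the `started_at`/`ended_at` guards. A sketch (the `do_lookup` call is a placeholder):

```python
span = get_trace_provider().create_span(
    span_data=FunctionSpanData(name="lookup", input='{"id": 1}', output=None)
)
span.start(mark_as_current=True)
try:
    span.span_data.output = do_lookup()  # placeholder for real work
except Exception as e:
    span.set_error({"message": str(e), "data": None})
    raise
finally:
    span.finish(reset_current=True)
```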
@@ -1,287 +0,0 @@
from __future__ import annotations

import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING

from . import util
from .scope import Scope

if TYPE_CHECKING:
    from .processor_interface import TracingProcessor


class Trace(abc.ABC):
    """A complete end-to-end workflow containing related spans and metadata.

    A trace represents a logical workflow or operation (e.g., "Customer Service Query"
    or "Code Generation") and contains all the spans (individual operations) that occur
    during that workflow.

    Example:
        ```python
        # Basic trace usage
        with trace("Order Processing") as t:
            validation_result = await Runner.run(validator, order_data)
            if validation_result.approved:
                await Runner.run(processor, order_data)

        # Trace with metadata and grouping
        with trace(
            "Customer Service",
            group_id="chat_123",
            metadata={"customer": "user_456"}
        ) as t:
            result = await Runner.run(support_agent, query)
        ```

    Notes:
        - Use descriptive workflow names
        - Group related traces with consistent group_ids
        - Add relevant metadata for filtering/analysis
        - Use context managers for reliable cleanup
        - Consider privacy when adding trace data
    """

    @abc.abstractmethod
    def __enter__(self) -> Trace:
        pass

    @abc.abstractmethod
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass

    @abc.abstractmethod
    def start(self, mark_as_current: bool = False) -> None:
        """Start the trace and optionally mark it as the current trace.

        Args:
            mark_as_current: If true, marks this trace as the current trace
                in the execution context.

        Notes:
            - Must be called before any spans can be added
            - Only one trace can be current at a time
            - Thread-safe when using mark_as_current
        """

    @abc.abstractmethod
    def finish(self, reset_current: bool = False) -> None:
        """Finish the trace and optionally reset the current trace.

        Args:
            reset_current: If true, resets the current trace to the previous
                trace in the execution context.

        Notes:
            - Must be called to complete the trace
            - Finalizes all open spans
            - Thread-safe when using reset_current
        """

    @property
    @abc.abstractmethod
    def trace_id(self) -> str:
        """Get the unique identifier for this trace.

        Returns:
            str: The trace's unique ID in the format 'trace_<32_alphanumeric>'

        Notes:
            - IDs are globally unique
            - Used to link spans to their parent trace
            - Can be used to look up traces in the dashboard
        """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Get the human-readable name of this workflow trace.

        Returns:
            str: The workflow name (e.g., "Customer Service", "Data Processing")

        Notes:
            - Should be descriptive and meaningful
            - Used for grouping and filtering in the dashboard
            - Helps identify the purpose of the trace
        """

    @abc.abstractmethod
    def export(self) -> dict[str, Any] | None:
        """Export the trace data as a serializable dictionary.

        Returns:
            dict | None: Dictionary containing trace data, or None if tracing is disabled.

        Notes:
            - Includes all spans and their data
            - Used for sending traces to backends
            - May include metadata and group ID
        """


class NoOpTrace(Trace):
    """A no-op implementation of Trace that doesn't record any data.

    Used when tracing is disabled but trace operations still need to work.
    Maintains proper context management but doesn't store or export any data.

    Example:
        ```python
        # When tracing is disabled, traces become NoOpTrace
        with trace("Disabled Workflow") as t:
            # Operations still work but nothing is recorded
            await Runner.run(agent, "query")
        ```
    """

    def __init__(self) -> None:
        self._started = False
        self._prev_context_token: contextvars.Token[Trace | None] | None = None

    def __enter__(self) -> Trace:
        if self._started:
            return self

        self._started = True
        self.start(mark_as_current=True)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.finish(reset_current=True)

    def start(self, mark_as_current: bool = False) -> None:
        if mark_as_current:
            self._prev_context_token = Scope.set_current_trace(self)

    def finish(self, reset_current: bool = False) -> None:
        if reset_current and self._prev_context_token is not None:
            Scope.reset_current_trace(self._prev_context_token)
            self._prev_context_token = None

    @property
    def trace_id(self) -> str:
        """The trace's unique identifier.

        Returns:
            str: A unique ID for this trace.
        """
        return "no-op"

    @property
    def name(self) -> str:
        """The workflow name for this trace.

        Returns:
            str: Human-readable name describing this workflow.
        """
        return "no-op"

    def export(self) -> dict[str, Any] | None:
        """Export the trace data as a dictionary.

        Returns:
            dict | None: Trace data in exportable format, or None if no data.
        """
        return None


NO_OP_TRACE = NoOpTrace()


class TraceImpl(Trace):
    """
    A trace that will be recorded by the tracing library.
    """

    __slots__ = (
        "_name",
        "_trace_id",
        "group_id",
        "metadata",
        "_prev_context_token",
        "_processor",
        "_started",
    )

    def __init__(
        self,
        name: str,
        trace_id: str | None,
        group_id: str | None,
        metadata: dict[str, Any] | None,
        processor: TracingProcessor,
    ):
        self._name = name
        self._trace_id = trace_id or util.gen_trace_id()
        self.group_id = group_id
        self.metadata = metadata
        self._prev_context_token: contextvars.Token[Trace | None] | None = None
        self._processor = processor
        self._started = False

    @property
    def trace_id(self) -> str:
        return self._trace_id

    @property
    def name(self) -> str:
        return self._name

    def start(self, mark_as_current: bool = False) -> None:
        if self._started:
            return

        self._started = True
        self._processor.on_trace_start(self)

        if mark_as_current:
            self._prev_context_token = Scope.set_current_trace(self)

    def finish(self, reset_current: bool = False) -> None:
        if not self._started:
            return

        self._processor.on_trace_end(self)

        if reset_current and self._prev_context_token is not None:
            Scope.reset_current_trace(self._prev_context_token)
            self._prev_context_token = None

    def __enter__(self) -> Trace:
        if self._started:
            return self

        self.start(mark_as_current=True)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.finish(reset_current=exc_type is not GeneratorExit)

    def export(self) -> dict[str, Any] | None:
        return {
            "object": "trace",
            "id": self.trace_id,
            "workflow_name": self.name,
            "group_id": self.group_id,
            "metadata": self.metadata,
        }
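One subtlety worth noting: `__exit__` deliberately skips resetting the context var on `GeneratorExit`, so a trace opened inside a generator does not clobber the caller's context when the generator is closed early. Manual lifecycle control mirrors the context manager; a sketch (`multi_processor` stands in for whatever `TracingProcessor` the provider holds):

```python
t = TraceImpl(
    name="batch-job",
    trace_id=None,  # auto-generated as trace_<32 hex chars>
    group_id="batch_2024_01",  # illustrative grouping key
    metadata={"source": "cron"},
    processor=multi_processor,  # e.g. a SynchronousMultiTracingProcessor
)
t.start(mark_as_current=True)
# ... run spans ...
t.finish(reset_current=True)
```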
@@ -1,18 +0,0 @@
import uuid
from datetime import datetime
from datetime import timezone


def time_iso() -> str:
    """Return the current time in ISO 8601 format."""
    return datetime.now(timezone.utc).isoformat()


def gen_trace_id() -> str:
    """Generate a new trace ID."""
    return f"trace_{uuid.uuid4().hex}"


def gen_span_id() -> str:
    """Generate a new span ID."""
    return f"span_{uuid.uuid4().hex[:24]}"
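These helpers pin the ID formats the rest of the framework expects (sample values illustrative):

```python
gen_trace_id()  # e.g. "trace_3f2b9c0d4e5a67889a0b1c2d3e4f5061"  (trace_ + 32 hex chars)
gen_span_id()   # e.g. "span_3f2b9c0d4e5a67889a0b1c"            (span_ + 24 hex chars)
```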
@@ -1,3 +1,4 @@
from onyx.configs.app_configs import LANGFUSE_HOST
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.utils.logger import setup_logger
@@ -12,20 +13,16 @@ def setup_langfuse_if_creds_available() -> None:
        return

    import nest_asyncio  # type: ignore

    nest_asyncio.apply()

    from langfuse import get_client
    from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor

    nest_asyncio.apply()
    OpenAIAgentsInstrumentor().instrument()
    # TODO: this is how the tracing processor will look once we migrate over to the new framework
    # config = TraceConfig()
    # tracer_provider = trace_api.get_tracer_provider()
    # tracer = OITracer(
    #     trace_api.get_tracer(__name__, __version__, tracer_provider),
    #     config=config,
    # )

    # set_trace_processors(
    #     [OpenInferenceTracingProcessor(cast(trace_api.Tracer, tracer))]
    # )
    langfuse = get_client()
    try:
        if langfuse.auth_check():
            logger.notice(f"Langfuse authentication successful (host: {LANGFUSE_HOST})")
        else:
            logger.warning("Langfuse authentication failed")
    except Exception as e:
        logger.error(f"Error setting up Langfuse: {e}")
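For local verification, this function is effectively a no-op unless the Langfuse credentials are configured; a hedged sketch (env var names follow the config imports above, values are placeholders, and the assumption is that Onyx reads these into `app_configs` at import time):

```python
import os

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."  # placeholder
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."  # placeholder
os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com"

# Must be set before onyx.configs.app_configs is imported (assumed
# import-time config loading); then the setup call authenticates.
setup_langfuse_if_creds_available()
```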
@@ -1,450 +0,0 @@
from __future__ import annotations

import logging
from collections import OrderedDict
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from typing import Union

from openinference.instrumentation import safe_json_dumps
from openinference.semconv.trace import MessageAttributes
from openinference.semconv.trace import MessageContentAttributes
from openinference.semconv.trace import OpenInferenceLLMProviderValues
from openinference.semconv.trace import OpenInferenceLLMSystemValues
from openinference.semconv.trace import OpenInferenceMimeTypeValues
from openinference.semconv.trace import OpenInferenceSpanKindValues
from openinference.semconv.trace import SpanAttributes
from openinference.semconv.trace import ToolAttributes
from openinference.semconv.trace import ToolCallAttributes
from opentelemetry.context import attach
from opentelemetry.context import detach
from opentelemetry.trace import set_span_in_context
from opentelemetry.trace import Span as OtelSpan
from opentelemetry.trace import Status
from opentelemetry.trace import StatusCode
from opentelemetry.trace import Tracer
from opentelemetry.util.types import AttributeValue

from onyx.tracing.framework.processor_interface import TracingProcessor
from onyx.tracing.framework.span_data import AgentSpanData
from onyx.tracing.framework.span_data import FunctionSpanData
from onyx.tracing.framework.span_data import GenerationSpanData
from onyx.tracing.framework.span_data import SpanData
from onyx.tracing.framework.spans import Span
from onyx.tracing.framework.traces import Trace

logger = logging.getLogger(__name__)


class OpenInferenceTracingProcessor(TracingProcessor):
    _MAX_HANDOFFS_IN_FLIGHT = 1000

    def __init__(self, tracer: Tracer) -> None:
        self._tracer = tracer
        self._root_spans: dict[str, OtelSpan] = {}
        self._otel_spans: dict[str, OtelSpan] = {}
        self._tokens: dict[str, object] = {}
        # This captures an in-flight handoff. Once the handoff is complete, the entry is deleted.
        # If the handoff does not complete, the entry stays in the dict.
        # Use an OrderedDict and _MAX_HANDOFFS_IN_FLIGHT to cap the size of the dict
        # in case there are large numbers of orphaned handoffs
        self._reverse_handoffs_dict: OrderedDict[str, str] = OrderedDict()
        self._first_input: dict[str, Any] = {}
        self._last_output: dict[str, Any] = {}

    def on_trace_start(self, trace: Trace) -> None:
        """Called when a trace is started.

        Args:
            trace: The trace that started.
        """
        otel_span = self._tracer.start_span(
            name=trace.name,
            attributes={
                OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.AGENT.value,
            },
        )
        self._root_spans[trace.trace_id] = otel_span

    def on_trace_end(self, trace: Trace) -> None:
        """Called when a trace is finished.

        Args:
            trace: The trace that finished.
        """
        if root_span := self._root_spans.pop(trace.trace_id, None):
            # Get the first input and last output for this specific trace
            trace_first_input = self._first_input.pop(trace.trace_id, None)
            trace_last_output = self._last_output.pop(trace.trace_id, None)

            # Set input/output attributes on the root span
            if trace_first_input is not None:
                try:
                    root_span.set_attribute(
                        INPUT_VALUE, safe_json_dumps(trace_first_input)
                    )
                    root_span.set_attribute(INPUT_MIME_TYPE, JSON)
                except Exception:
                    # Fallback to string if JSON serialization fails
                    root_span.set_attribute(INPUT_VALUE, str(trace_first_input))

            if trace_last_output is not None:
                try:
                    root_span.set_attribute(
                        OUTPUT_VALUE, safe_json_dumps(trace_last_output)
                    )
                    root_span.set_attribute(OUTPUT_MIME_TYPE, JSON)
                except Exception:
                    # Fallback to string if JSON serialization fails
                    root_span.set_attribute(OUTPUT_VALUE, str(trace_last_output))

            root_span.set_status(Status(StatusCode.OK))
            root_span.end()
        else:
            # Clean up stored input/output for this trace if root span doesn't exist
            self._first_input.pop(trace.trace_id, None)
            self._last_output.pop(trace.trace_id, None)

    def on_span_start(self, span: Span[Any]) -> None:
        """Called when a span is started.

        Args:
            span: The span that started.
        """
        if not span.started_at:
            return
        start_time = datetime.fromisoformat(span.started_at)
        parent_span = (
            self._otel_spans.get(span.parent_id)
            if span.parent_id
            else self._root_spans.get(span.trace_id)
        )
        context = set_span_in_context(parent_span) if parent_span else None
        span_name = _get_span_name(span)
        otel_span = self._tracer.start_span(
            name=span_name,
            context=context,
            start_time=_as_utc_nano(start_time),
            attributes={
                OPENINFERENCE_SPAN_KIND: _get_span_kind(span.span_data),
                LLM_SYSTEM: OpenInferenceLLMSystemValues.OPENAI.value,
            },
        )
        self._otel_spans[span.span_id] = otel_span
        self._tokens[span.span_id] = attach(set_span_in_context(otel_span))

    def on_span_end(self, span: Span[Any]) -> None:
        """Called when a span is finished. Should not block or raise exceptions.

        Args:
            span: The span that finished.
        """
        if token := self._tokens.pop(span.span_id, None):
            detach(token)  # type: ignore[arg-type]
        if not (otel_span := self._otel_spans.pop(span.span_id, None)):
            return
        otel_span.update_name(_get_span_name(span))
        # flatten_attributes: dict[str, AttributeValue] = dict(_flatten(span.export()))
        # otel_span.set_attributes(flatten_attributes)
        data = span.span_data
        if isinstance(data, GenerationSpanData):
            for k, v in _get_attributes_from_generation_span_data(data):
                otel_span.set_attribute(k, v)
        elif isinstance(data, FunctionSpanData):
            for k, v in _get_attributes_from_function_span_data(data):
                otel_span.set_attribute(k, v)
        elif isinstance(data, AgentSpanData):
            otel_span.set_attribute(GRAPH_NODE_ID, data.name)
            # Look up the parent node if it exists
            key = f"{data.name}:{span.trace_id}"
            if parent_node := self._reverse_handoffs_dict.pop(key, None):
                otel_span.set_attribute(GRAPH_NODE_PARENT_ID, parent_node)

        end_time: Optional[int] = None
        if span.ended_at:
            try:
                end_time = _as_utc_nano(datetime.fromisoformat(span.ended_at))
            except ValueError:
                pass
        otel_span.set_status(status=_get_span_status(span))
        otel_span.end(end_time)

        # Store first input and last output per trace_id
        trace_id = span.trace_id
        input_: Optional[Any] = None
        output: Optional[Any] = None

        if isinstance(data, FunctionSpanData):
            input_ = data.input
            output = data.output
        elif isinstance(data, GenerationSpanData):
            input_ = data.input
            output = data.output

        if trace_id not in self._first_input and input_ is not None:
            self._first_input[trace_id] = input_

        if output is not None:
            self._last_output[trace_id] = output

    def force_flush(self) -> None:
        """Forces an immediate flush of all queued spans/traces."""
        # TODO

    def shutdown(self) -> None:
        """Called when the application stops."""
        # TODO


def _as_utc_nano(dt: datetime) -> int:
    return int(dt.astimezone(timezone.utc).timestamp() * 1_000_000_000)


def _get_span_name(obj: Span[Any]) -> str:
    if hasattr(data := obj.span_data, "name") and isinstance(name := data.name, str):
        return name
    return obj.span_data.type  # type: ignore[no-any-return]


def _get_span_kind(obj: SpanData) -> str:
    if isinstance(obj, AgentSpanData):
        return OpenInferenceSpanKindValues.AGENT.value
    if isinstance(obj, FunctionSpanData):
        return OpenInferenceSpanKindValues.TOOL.value
    if isinstance(obj, GenerationSpanData):
        return OpenInferenceSpanKindValues.LLM.value
    return OpenInferenceSpanKindValues.CHAIN.value


def _get_attributes_from_generation_span_data(
    obj: GenerationSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
    if isinstance(model := obj.model, str):
        yield LLM_MODEL_NAME, model
    if isinstance(obj.model_config, dict) and (
        param := {k: v for k, v in obj.model_config.items() if v is not None}
    ):
        yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(param)
        if base_url := param.get("base_url"):
            if "api.openai.com" in base_url:
                yield LLM_PROVIDER, OpenInferenceLLMProviderValues.OPENAI.value
    yield from _get_attributes_from_chat_completions_input(obj.input)
    yield from _get_attributes_from_chat_completions_output(obj.output)
    yield from _get_attributes_from_chat_completions_usage(obj.usage)


def _get_attributes_from_chat_completions_input(
    obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
    if not obj:
        return
    try:
        yield INPUT_VALUE, safe_json_dumps(obj)
        yield INPUT_MIME_TYPE, JSON
    except Exception:
        pass
    yield from _get_attributes_from_chat_completions_message_dicts(
        obj,
        f"{LLM_INPUT_MESSAGES}.",
    )


def _get_attributes_from_chat_completions_output(
    obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
    if not obj:
        return
    try:
        yield OUTPUT_VALUE, safe_json_dumps(obj)
        yield OUTPUT_MIME_TYPE, JSON
    except Exception:
        pass
    yield from _get_attributes_from_chat_completions_message_dicts(
        obj,
        f"{LLM_OUTPUT_MESSAGES}.",
    )


def _get_attributes_from_chat_completions_message_dicts(
    obj: Iterable[Mapping[str, Any]],
    prefix: str = "",
    msg_idx: int = 0,
    tool_call_idx: int = 0,
) -> Iterator[tuple[str, AttributeValue]]:
    if not isinstance(obj, Iterable):
        return
    for msg in obj:
        if isinstance(role := msg.get("role"), str):
            yield f"{prefix}{msg_idx}.{MESSAGE_ROLE}", role
        if content := msg.get("content"):
            yield from _get_attributes_from_chat_completions_message_content(
                content,
                f"{prefix}{msg_idx}.",
            )
        if isinstance(tool_call_id := msg.get("tool_call_id"), str):
            yield f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALL_ID}", tool_call_id
        if isinstance(tool_calls := msg.get("tool_calls"), Iterable):
            for tc in tool_calls:
                yield from _get_attributes_from_chat_completions_tool_call_dict(
                    tc,
                    f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALLS}.{tool_call_idx}.",
                )
                tool_call_idx += 1
        msg_idx += 1


def _get_attributes_from_chat_completions_message_content(
    obj: Union[str, Iterable[Mapping[str, Any]]],
    prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
    if isinstance(obj, str):
        yield f"{prefix}{MESSAGE_CONTENT}", obj
    elif isinstance(obj, Iterable):
        for i, item in enumerate(obj):
            if not isinstance(item, Mapping):
                continue
            yield from _get_attributes_from_chat_completions_message_content_item(
                item,
                f"{prefix}{MESSAGE_CONTENTS}.{i}.",
            )


def _get_attributes_from_chat_completions_message_content_item(
    obj: Mapping[str, Any],
    prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
    if obj.get("type") == "text" and (text := obj.get("text")):
        yield f"{prefix}{MESSAGE_CONTENT_TYPE}", "text"
        yield f"{prefix}{MESSAGE_CONTENT_TEXT}", text


def _get_attributes_from_chat_completions_tool_call_dict(
    obj: Mapping[str, Any],
    prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
    if id_ := obj.get("id"):
        yield f"{prefix}{TOOL_CALL_ID}", id_
    if function := obj.get("function"):
        if name := function.get("name"):
            yield f"{prefix}{TOOL_CALL_FUNCTION_NAME}", name
        if arguments := function.get("arguments"):
            if arguments != "{}":
                yield f"{prefix}{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments


def _get_attributes_from_chat_completions_usage(
    obj: Optional[Mapping[str, Any]],
) -> Iterator[tuple[str, AttributeValue]]:
    if not obj:
        return
    if input_tokens := obj.get("input_tokens"):
        yield LLM_TOKEN_COUNT_PROMPT, input_tokens
    if output_tokens := obj.get("output_tokens"):
        yield LLM_TOKEN_COUNT_COMPLETION, output_tokens


# convert dict, tuple, etc. into one of these types: ['bool', 'str', 'bytes', 'int', 'float']
def _convert_to_primitive(value: Any) -> Union[bool, str, bytes, int, float]:
    if isinstance(value, (bool, str, bytes, int, float)):
        return value
    if isinstance(value, (list, tuple)):
        return safe_json_dumps(value)
    if isinstance(value, dict):
        return safe_json_dumps(value)
    return str(value)


def _get_attributes_from_function_span_data(
    obj: FunctionSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
    yield TOOL_NAME, obj.name
    if obj.input:
        yield INPUT_VALUE, obj.input
        yield INPUT_MIME_TYPE, JSON
    if obj.output is not None:
        yield OUTPUT_VALUE, _convert_to_primitive(obj.output)
        if (
            isinstance(obj.output, str)
            and len(obj.output) > 1
            and obj.output[0] == "{"
            and obj.output[-1] == "}"
        ):
            yield OUTPUT_MIME_TYPE, JSON


def _get_span_status(obj: Span[Any]) -> Status:
    if error := getattr(obj, "error", None):
        return Status(
            status_code=StatusCode.ERROR,
            description=f"{error.get('message')}: {error.get('data')}",
        )
    else:
        return Status(StatusCode.OK)


def _flatten(
    obj: Mapping[str, Any],
    prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
    for key, value in obj.items():
        if isinstance(value, dict):
            yield from _flatten(value, f"{prefix}{key}.")
        elif isinstance(value, (str, int, float, bool)):
            yield f"{prefix}{key}", value
        else:
            yield f"{prefix}{key}", str(value)


INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = (
    SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
)
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
    SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
)
LLM_TOOLS = SpanAttributes.LLM_TOOLS
METADATA = SpanAttributes.METADATA
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION
TOOL_NAME = SpanAttributes.TOOL_NAME
TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
GRAPH_NODE_ID = SpanAttributes.GRAPH_NODE_ID
GRAPH_NODE_PARENT_ID = SpanAttributes.GRAPH_NODE_PARENT_ID

MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_CONTENTS = MessageAttributes.MESSAGE_CONTENTS
MESSAGE_CONTENT_IMAGE = MessageContentAttributes.MESSAGE_CONTENT_IMAGE
MESSAGE_CONTENT_TEXT = MessageContentAttributes.MESSAGE_CONTENT_TEXT
MESSAGE_CONTENT_TYPE = MessageContentAttributes.MESSAGE_CONTENT_TYPE
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = (
    MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
)
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID

TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID

TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA

JSON = OpenInferenceMimeTypeValues.JSON.value
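Wiring this processor into the framework might look like the sketch below, following the commented-out TODO in the Langfuse setup earlier in this diff (the `register_processor` call assumes the `DefaultTraceProvider` from this change has been installed globally):

```python
from opentelemetry import trace as trace_api

tracer = trace_api.get_tracer(__name__)  # assumes an OTel TracerProvider is configured
get_trace_provider().register_processor(OpenInferenceTracingProcessor(tracer))
```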
@@ -50,7 +50,7 @@ nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==2.6.1
openpyxl==3.0.10
openpyxl==3.1.5
passlib==1.7.4
playwright==1.55.0
psutil==5.9.5
@@ -84,7 +84,7 @@ supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
types-openpyxl==3.0.4.7
types-openpyxl==3.1.5.20250919
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.35.0
@@ -108,7 +108,7 @@ exa_py==1.15.4
braintrust[openai-agents]==0.2.6
braintrust-langchain==0.0.4
openai-agents==0.4.2
langfuse==3.10.0
langfuse==3.7.0
nest_asyncio==1.6.0
openinference-instrumentation-openai-agents==1.3.0
opentelemetry-proto==1.38.0
@@ -19,14 +19,3 @@ class RerankerProvider(str, Enum):
class EmbedTextType(str, Enum):
    QUERY = "query"
    PASSAGE = "passage"


class WebSearchProviderType(str, Enum):
    GOOGLE_PSE = "google_pse"
    SERPER = "serper"
    EXA = "exa"


class WebContentProviderType(str, Enum):
    ONYX_WEB_CRAWLER = "onyx_web_crawler"
    FIRECRAWL = "firecrawl"
@@ -154,7 +154,7 @@ def test_versions_endpoint(client: TestClient) -> None:
    assert migration["onyx"] == "airgapped-intfloat-nomic-migration"
    assert migration["relational_db"] == "postgres:15.2-alpine"
    assert migration["index"] == "vespaengine/vespa:8.277.17"
    assert migration["nginx"] == "nginx:1.25.5-alpine"
    assert migration["nginx"] == "nginx:1.23.4-alpine"

    # Verify versions are different between stable and dev
    assert stable["onyx"] != dev["onyx"], "Stable and dev versions should be different"
@@ -2,8 +2,6 @@ import os
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
@@ -520,52 +518,3 @@ class TestHubSpotConnector:
        note_url = connector._get_object_url("notes", "44444")
        expected_note_url = "https://app.hubspot.com/contacts/12345/objects/0-4/44444"
        assert note_url == expected_note_url

    def test_ticket_with_none_content(self) -> None:
        """Test that tickets with None content are handled gracefully."""
        connector = HubSpotConnector(object_types=["tickets"], batch_size=10)
        connector._access_token = "mock_token"
        connector._portal_id = "mock_portal_id"

        # Create a mock ticket with None content
        mock_ticket = MagicMock()
        mock_ticket.id = "12345"
        mock_ticket.properties = {
            "subject": "Test Ticket",
            "content": None,  # This is the key test case
            "hs_ticket_priority": "HIGH",
        }
        mock_ticket.updated_at = datetime.now(timezone.utc)

        # Mock the HubSpot API client
        mock_api_client = MagicMock()

        # Mock the API calls and associated object methods
        with patch(
            "onyx.connectors.hubspot.connector.HubSpot"
        ) as MockHubSpot, patch.object(
            connector, "_paginated_results"
        ) as mock_paginated, patch.object(
            connector, "_get_associated_objects", return_value=[]
        ), patch.object(
            connector, "_get_associated_notes", return_value=[]
        ):
            MockHubSpot.return_value = mock_api_client
            mock_paginated.return_value = iter([mock_ticket])

            # This should not raise a validation error
            document_batches = connector._process_tickets()
            first_batch = next(document_batches, None)

            # Verify the document was created successfully
            assert first_batch is not None
            assert len(first_batch) == 1

            doc = first_batch[0]
            assert doc.id == "hubspot_ticket_12345"
            assert doc.semantic_identifier == "Test Ticket"

            # Verify the first section has an empty string, not None
            assert len(doc.sections) > 0
            assert doc.sections[0].text == ""  # Should be empty string, not None
            assert doc.sections[0].link is not None
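The removed HubSpot test above pins down one behavior: a ticket whose "content" property is None must still yield a document whose first section carries an empty string. A minimal sketch of that coercion, assuming illustrative names (TextSection and build_ticket_section are not the connector's actual API):

from dataclasses import dataclass

@dataclass
class TextSection:
    text: str  # section models require a str, never None
    link: str | None = None

def build_ticket_section(props: dict, ticket_url: str) -> TextSection:
    # Coerce a None "content" property to "" so validation does not reject
    # the document, matching the assertion doc.sections[0].text == "".
    return TextSection(text=props.get("content") or "", link=ticket_url)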
@@ -1,6 +1,4 @@
import os
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4
@@ -10,11 +8,9 @@ os.environ["MODEL_SERVER_HOST"] = "disabled"
os.environ["MODEL_SERVER_PORT"] = "9000"

from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError

from onyx.configs.constants import FederatedConnectorSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
from onyx.db.models import LLMProvider
@@ -577,174 +573,3 @@ class TestSlackBotFederatedSearch:

        finally:
            self._teardown_common_mocks(patches)


@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_missing_scope_resilience(
    mock_web_client: Mock, mock_redis_client: Mock
) -> None:
    """Test that missing scopes are handled gracefully"""
    # Setup mock Redis client
    mock_redis = MagicMock()
    mock_redis.get.return_value = None  # Cache miss
    mock_redis_client.return_value = mock_redis

    # Setup mock Slack client that simulates missing_scope error
    mock_client_instance = MagicMock()
    mock_web_client.return_value = mock_client_instance

    # Track which channel types were attempted
    attempted_types: list[str] = []

    def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
        if types:
            attempted_types.append(types)

        # First call: all types including mpim -> missing_scope error
        if types and "mpim" in types:
            error_response = {
                "ok": False,
                "error": "missing_scope",
                "needed": "mpim:read",
                "provided": "identify,channels:history,channels:read,groups:read,im:read,search:read",
            }
            raise SlackApiError("missing_scope", error_response)

        # Second call: without mpim -> success
        mock_response = MagicMock()
        mock_response.validate.return_value = None
        mock_response.data = {
            "channels": [
                {
                    "id": "C1234567890",
                    "name": "general",
                    "is_channel": True,
                    "is_private": False,
                    "is_group": False,
                    "is_mpim": False,
                    "is_im": False,
                    "is_member": True,
                },
                {
                    "id": "D9876543210",
                    "name": "",
                    "is_channel": False,
                    "is_private": False,
                    "is_group": False,
                    "is_mpim": False,
                    "is_im": True,
                    "is_member": True,
                },
            ],
            "response_metadata": {},
        }
        return mock_response

    mock_client_instance.conversations_list.side_effect = mock_conversations_list

    # Call the function
    result = fetch_and_cache_channel_metadata(
        access_token="xoxp-test-token",
        team_id="T1234567890",
        include_private=True,
    )

    # Assertions
    # Should have attempted twice: once with mpim, once without
    assert len(attempted_types) == 2, f"Expected 2 attempts, got {len(attempted_types)}"
    assert "mpim" in attempted_types[0], "First attempt should include mpim"
    assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"

    # Should have successfully returned channels despite missing scope
    assert len(result) == 2, f"Expected 2 channels, got {len(result)}"
    assert "C1234567890" in result, "Should have public channel"
    assert "D9876543210" in result, "Should have DM channel"

    # Verify channel metadata structure
    assert result["C1234567890"]["name"] == "general"
    assert result["C1234567890"]["type"] == "public_channel"
    assert result["D9876543210"]["type"] == "im"


@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_multiple_missing_scopes_resilience(
    mock_web_client: Mock, mock_redis_client: Mock
) -> None:
    """Test handling multiple missing scopes gracefully"""
    # Setup mock Redis client
    mock_redis = MagicMock()
    mock_redis.get.return_value = None  # Cache miss
    mock_redis_client.return_value = mock_redis

    # Setup mock Slack client
    mock_client_instance = MagicMock()
    mock_web_client.return_value = mock_client_instance

    # Track attempts
    attempted_types: list[str] = []

    def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
        if types:
            attempted_types.append(types)

        # First: mpim missing
        if types and "mpim" in types:
            error_response = {
                "ok": False,
                "error": "missing_scope",
                "needed": "mpim:read",
                "provided": "identify,channels:history,channels:read,groups:read",
            }
            raise SlackApiError("missing_scope", error_response)

        # Second: im missing
        if types and "im" in types:
            error_response = {
                "ok": False,
                "error": "missing_scope",
                "needed": "im:read",
                "provided": "identify,channels:history,channels:read,groups:read",
            }
            raise SlackApiError("missing_scope", error_response)

        # Third: success with only public and private channels
        mock_response = MagicMock()
        mock_response.validate.return_value = None
        mock_response.data = {
            "channels": [
                {
                    "id": "C1234567890",
                    "name": "general",
                    "is_channel": True,
                    "is_private": False,
                    "is_group": False,
                    "is_mpim": False,
                    "is_im": False,
                    "is_member": True,
                }
            ],
            "response_metadata": {},
        }
        return mock_response

    mock_client_instance.conversations_list.side_effect = mock_conversations_list

    # Call the function
    result = fetch_and_cache_channel_metadata(
        access_token="xoxp-test-token",
        team_id="T1234567890",
        include_private=True,
    )

    # Should gracefully handle multiple missing scopes
    assert len(attempted_types) == 3, f"Expected 3 attempts, got {len(attempted_types)}"
    assert "mpim" in attempted_types[0], "First attempt should include mpim"
    assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"
    assert "im" in attempted_types[1], "Second attempt should include im"
    assert "im" not in attempted_types[2], "Third attempt should not include im"

    # Should still return available channels
    assert len(result) == 1, f"Expected 1 channel, got {len(result)}"
    assert result["C1234567890"]["name"] == "general"
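Both removed Slack tests exercise the same retry pattern: when conversations_list fails with missing_scope, drop the channel type gated by the scope named in "needed" and try again. A minimal sketch of that loop, assuming an illustrative helper name and scope-to-type mapping (not the module's actual code):

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError

# Illustrative mapping from a missing OAuth scope to the channel type it gates.
_SCOPE_TO_TYPE = {
    "mpim:read": "mpim",
    "im:read": "im",
    "groups:read": "private_channel",
}

def list_channels_resilient(client: WebClient, types: list[str]) -> list[dict]:
    # Retry conversations_list with progressively fewer channel types until it
    # succeeds or every type gated by a missing scope has been dropped.
    while types:
        try:
            response = client.conversations_list(types=",".join(types))
            return response.get("channels", [])
        except SlackApiError as e:
            needed = e.response.get("needed") if e.response else None
            blocked = _SCOPE_TO_TYPE.get(needed or "")
            if blocked in types:
                types = [t for t in types if t != blocked]  # drop and retry
            else:
                raise  # not a recoverable missing_scope error
    return []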
@@ -140,22 +140,6 @@ class CCPairManager:
        )
        result.raise_for_status()

    @staticmethod
    def unpause_cc_pair(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        result = requests.put(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
            json={"status": "ACTIVE"},
            headers=(
                user_performing_action.headers
                if user_performing_action
                else GENERAL_HEADERS
            ),
        )
        result.raise_for_status()

    @staticmethod
    def delete(
        cc_pair: DATestCCPair,
@@ -382,7 +382,6 @@ def test_mock_connector_checkpoint_recovery(
    assert response.status_code == 200

    # Create CC Pair and run initial indexing attempt
    # Note: Setting refresh_freq to allow manual retrigger after failure
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-checkpoint-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
@@ -392,7 +391,6 @@ def test_mock_connector_checkpoint_recovery(
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
        refresh_freq=60 * 60,  # 1 hour
    )

    # Wait for first index attempt to complete
@@ -465,14 +463,10 @@ def test_mock_connector_checkpoint_recovery(
    )
    assert response.status_code == 200

    # After the failure, the connector is in repeated error state and paused.
    # Set the manual indexing trigger first (while paused), then unpause.
    # This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
    # prevent the connector from being re-paused when repeated error state is detected.
    # Trigger another indexing attempt
    CCPairManager.run_once(
        cc_pair, from_beginning=False, user_performing_action=admin_user
    )
    CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
@@ -11,7 +11,6 @@ from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import InputType
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
@@ -122,11 +121,6 @@ def test_repeated_error_state_detection_and_recovery(
    )
    assert cc_pair_obj is not None
    if cc_pair_obj.in_repeated_error_state:
        # Verify the connector is also paused to prevent further indexing attempts
        assert cc_pair_obj.status == ConnectorCredentialPairStatus.PAUSED, (
            f"Expected status to be PAUSED when in repeated error state, "
            f"but got {cc_pair_obj.status}"
        )
        break

    if time.monotonic() - start_time > 30:
@@ -151,13 +145,10 @@ def test_repeated_error_state_detection_and_recovery(
    )
    assert response.status_code == 200

    # Set the manual indexing trigger first (while paused), then unpause.
    # This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
    # prevent the connector from being re-paused when repeated error state is detected.
    # Run another indexing attempt that should succeed
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)

    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
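The comments deleted in the hunks above describe a deliberate ordering for recovering a connector out of the repeated-error state: queue the manual run first, then unpause, so the pending trigger is already visible when CHECK_FOR_INDEXING next fires and the pair is not immediately re-paused. A minimal sketch of that ordering, reusing the CCPairManager test helpers shown earlier (the wrapper function itself is illustrative):

def recover_from_repeated_error_state(cc_pair, admin_user) -> None:
    # 1. Queue a manual indexing attempt while the pair is still PAUSED, so
    #    the trigger already exists when CHECK_FOR_INDEXING next runs.
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    # 2. Only then flip the pair back to ACTIVE; the pending trigger stops
    #    the repeated-error check from immediately re-pausing it.
    CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)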
@@ -140,84 +140,7 @@ def tool_call_chunk(
    )


class FakeErrorTool(Tool):
    """Base fake tool for testing."""

    def __init__(self, tool_name: str, tool_id: int = 1):
        self._tool_name = tool_name
        self._tool_id = tool_id
        self.calls: list[dict[str, Any]] = []

    @property
    def id(self) -> int:
        return self._tool_id

    @property
    def name(self) -> str:
        return self._tool_name

    @property
    def description(self) -> str:
        return f"{self._tool_name} tool"

    @property
    def display_name(self) -> str:
        return self._tool_name.replace("_", " ").title()

    def tool_definition(self) -> dict:
        return {
            "type": "function",
            "function": {
                "name": self._tool_name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "queries": {
                            "type": "array",
                            "items": {"type": "string"},
                        }
                    },
                    "required": ["queries"],
                },
            },
        }

    def run_v2(
        self,
        run_context: RunContextWrapper[Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        raise Exception("Error running tool")

    def build_tool_message_content(self, *args: Any) -> str:
        return ""

    def get_args_for_non_tool_calling_llm(
        self, query: Any, history: Any, llm: Any, force_run: bool = False
    ) -> None:
        return None

    def run(
        self, override_kwargs: Any = None, **llm_kwargs: Any
    ) -> Generator[ToolResponse, None, None]:
        raise NotImplementedError
        yield  # Make this a generator

    def final_result(self, *args: Any) -> dict[str, Any]:
        return {}

    def build_next_prompt(
        self,
        prompt_builder: Any,
        tool_call_summary: Any,
        tool_responses: Any,
        using_tool_calling_llm: Any,
    ) -> Any:
        return prompt_builder


# Fake tools for testing
class FakeTool(Tool):
    """Base fake tool for testing."""

@@ -269,9 +192,7 @@ class FakeTool(Tool):
    ) -> Any:
        queries = kwargs.get("queries", [])
        self.calls.append({"queries": queries})
        context = run_context.context
        flag_name = f"{self._tool_name}_called"
        context[flag_name] = True
        run_context.context[f"{self._tool_name}_called"] = True
        return f"{self.display_name} results for: {', '.join(queries)}"

    def build_tool_message_content(self, *args: Any) -> str:
@@ -311,9 +232,3 @@ def fake_internal_search_tool() -> FakeTool:
def fake_web_search_tool() -> FakeTool:
    """Fixture providing a fake web search tool."""
    return FakeTool("web_search", tool_id=2)


@pytest.fixture
def fake_error_tool() -> FakeErrorTool:
    """Fixture providing a fake error tool."""
    return FakeErrorTool("error_tool", tool_id=3)
@@ -7,7 +7,6 @@ from onyx.agents.agent_framework.models import ToolCallStreamItem
from onyx.agents.agent_framework.query import query
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.model_response import ModelResponseStream
from tests.unit.onyx.agents.agent_framework.conftest import FakeErrorTool
from tests.unit.onyx.agents.agent_framework.conftest import FakeTool
from tests.unit.onyx.agents.agent_framework.conftest import stream_chunk
from tests.unit.onyx.agents.agent_framework.conftest import tool_call_chunk
@@ -464,165 +463,3 @@ def test_query_emits_reasoning_done_before_message_start(
    assert len(result.new_messages_stateful) == 1
    assert result.new_messages_stateful[0]["role"] == "assistant"
    assert result.new_messages_stateful[0]["content"] == "Based on my analysis"


def test_query_treats_tool_like_message_as_tool_call(
    fake_llm: Callable[[list[ModelResponseStream]], Any],
    fake_internal_search_tool: FakeTool,
) -> None:
    """Test that query treats JSON message content as a tool call when structured_response_format is empty."""
    stream_id = "chatcmpl-tool-like-message"

    responses = [
        stream_chunk(
            id=stream_id,
            created="1762545000",
            content='{"name": "internal_search", ',
        ),
        stream_chunk(
            id=stream_id,
            created="1762545000",
            content='"arguments": {"queries": ["new agent", "framework"]}',
        ),
        stream_chunk(id=stream_id, created="1762545000", content="}"),
        stream_chunk(id=stream_id, created="1762545000", finish_reason="stop"),
    ]

    llm = fake_llm(responses)
    context: dict[str, bool] = {}
    messages: list[ChatCompletionMessage] = [
        {"role": "user", "content": "tell me about agents"}
    ]

    result = query(
        llm,
        messages,
        tools=[fake_internal_search_tool],
        context=context,
        tool_choice=None,
        structured_response_format={},
    )

    events = list(result.stream)
    run_item_events = [e for e in events if isinstance(e, RunItemStreamEvent)]

    tool_call_events = [e for e in run_item_events if e.type == "tool_call"]
    assert len(tool_call_events) == 1
    assert tool_call_events[0].details is not None
    assert isinstance(tool_call_events[0].details, ToolCallStreamItem)
    assert tool_call_events[0].details.name == "internal_search"
    assert (
        tool_call_events[0].details.arguments
        == '{"queries": ["new agent", "framework"]}'
    )

    assert len(fake_internal_search_tool.calls) == 1
    assert fake_internal_search_tool.calls[0]["queries"] == ["new agent", "framework"]
    assert context["internal_search_called"] is True

    tool_output_events = [e for e in run_item_events if e.type == "tool_call_output"]
    assert len(tool_output_events) == 1
    assert tool_output_events[0].details is not None
    assert isinstance(tool_output_events[0].details, ToolCallOutputStreamItem)
    assert (
        tool_output_events[0].details.output
        == "Internal Search results for: new agent, framework"
    )

    assert len(result.new_messages_stateful) == 2
    assert result.new_messages_stateful[0]["role"] == "assistant"
    assert result.new_messages_stateful[0]["content"] is None
    assert len(result.new_messages_stateful[0]["tool_calls"]) == 1
    assert (
        result.new_messages_stateful[0]["tool_calls"][0]["function"]["name"]
        == "internal_search"
    )
    assert result.new_messages_stateful[1]["role"] == "tool"
    assert (
        result.new_messages_stateful[1]["tool_call_id"]
        == tool_call_events[0].details.call_id
    )
    assert result.new_messages_stateful[1]["content"] == (
        "Internal Search results for: new agent, framework"
    )


def test_query_handles_tool_error_gracefully(
    fake_llm: Callable[[list[ModelResponseStream]], Any],
    fake_error_tool: FakeErrorTool,
) -> None:
    """Test that query handles tool errors gracefully and treats the error as the tool call response."""
    call_id = "toolu_01ErrorToolCall123"
    stream_id = "chatcmpl-error-test-12345"

    responses = [
        stream_chunk(
            id=stream_id,
            created="1762545000",
            content="",
            tool_calls=[tool_call_chunk(id=call_id, name="error_tool", arguments="")],
        ),
        stream_chunk(
            id=stream_id,
            created="1762545000",
            content="",
            tool_calls=[tool_call_chunk(arguments="")],
        ),
        *[
            stream_chunk(
                id=stream_id,
                created="1762545000",
                content="",
                tool_calls=[tool_call_chunk(arguments=arg)],
            )
            for arg in ['{"queries": ', '["test"]}']
        ],
        stream_chunk(id=stream_id, created="1762545000", finish_reason="tool_calls"),
    ]

    llm = fake_llm(responses)
    context: dict[str, bool] = {}
    messages: list[ChatCompletionMessage] = [
        {"role": "user", "content": "call the error tool"}
    ]

    result = query(
        llm,
        messages,
        tools=[fake_error_tool],
        context=context,
        tool_choice=None,
    )
    events = list(result.stream)

    run_item_events = [e for e in events if isinstance(e, RunItemStreamEvent)]

    # Verify tool_call event was emitted
    tool_call_events = [e for e in run_item_events if e.type == "tool_call"]
    assert len(tool_call_events) == 1
    assert tool_call_events[0].details is not None
    assert isinstance(tool_call_events[0].details, ToolCallStreamItem)
    assert tool_call_events[0].details.call_id == call_id
    assert tool_call_events[0].details.name == "error_tool"

    # Verify tool_call_output event was emitted with the error message
    tool_output_events = [e for e in run_item_events if e.type == "tool_call_output"]
    assert len(tool_output_events) == 1
    assert tool_output_events[0].details is not None
    assert isinstance(tool_output_events[0].details, ToolCallOutputStreamItem)
    assert tool_output_events[0].details.call_id == call_id
    assert tool_output_events[0].details.output == "Error: Error running tool"

    # Verify new_messages contains expected assistant and tool messages
    assert len(result.new_messages_stateful) == 2
    assert result.new_messages_stateful[0]["role"] == "assistant"
    assert result.new_messages_stateful[0]["content"] is None
    assert len(result.new_messages_stateful[0]["tool_calls"]) == 1
    assert result.new_messages_stateful[0]["tool_calls"][0]["id"] == call_id
    assert (
        result.new_messages_stateful[0]["tool_calls"][0]["function"]["name"]
        == "error_tool"
    )
    assert result.new_messages_stateful[1]["role"] == "tool"
    assert result.new_messages_stateful[1]["tool_call_id"] == call_id
    assert result.new_messages_stateful[1]["content"] == "Error: Error running tool"
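The first removed test asserts a fallback in query: when structured_response_format is empty and the assistant message content parses as {"name": ..., "arguments": ...} for a known tool, it is executed as a tool call rather than returned as text. A minimal sketch of that detection step, with illustrative names (not the framework's internals):

import json

def as_tool_call(content: str, known_tools: set[str]) -> tuple[str, str] | None:
    # Promote JSON message content of the form {"name": ..., "arguments": ...}
    # to a (tool_name, arguments_json) pair; return None for ordinary text.
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    name = parsed.get("name")
    if name not in known_tools or "arguments" not in parsed:
        return None
    return name, json.dumps(parsed["arguments"])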
@@ -2,9 +2,6 @@ from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage as LangChainSystemMessage

from onyx.agents.agent_framework.message_format import (
    base_messages_to_chat_completion_msgs,
)
from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
from onyx.agents.agent_sdk.message_types import InputTextContent
@@ -215,46 +212,3 @@ def test_assistant_message_list_content_non_responses_api() -> None:
    assert second_content["type"] == "input_text"
    text_content2: InputTextContent = second_content  # type: ignore[assignment]
    assert text_content2["text"] == "It's known for the Eiffel Tower."


def test_base_messages_to_chat_completion_msgs_basic() -> None:
    """Ensure system and user messages convert to chat completion format."""
    system_message = LangChainSystemMessage(
        content="You are a helpful assistant.",
        additional_kwargs={},
        response_metadata={},
    )
    human_message = HumanMessage(
        content="hello",
        additional_kwargs={},
        response_metadata={},
    )

    results = base_messages_to_chat_completion_msgs([system_message, human_message])

    assert results == [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "hello"},
    ]


def test_base_messages_to_chat_completion_msgs_with_tool_call() -> None:
    """Ensure assistant messages with tool calls are preserved."""
    ai_message = AIMessage(
        content="",
        tool_calls=[
            {
                "id": "call_1",
                "name": "internal_search",
                "args": {"query": "test"},
            }
        ],
        additional_kwargs={},
        response_metadata={},
    )

    results = base_messages_to_chat_completion_msgs([ai_message])

    assert len(results) == 1
    assert results[0]["role"] == "assistant"
    assert results[0]["tool_calls"][0]["function"]["name"] == "internal_search"
@@ -1,17 +0,0 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
    OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    build_content_provider_from_config,
)
from shared_configs.enums import WebContentProviderType


def test_build_content_provider_returns_onyx_crawler() -> None:
    provider = build_content_provider_from_config(
        provider_type=WebContentProviderType.ONYX_WEB_CRAWLER,
        api_key=None,
        config={"timeout_seconds": "20"},
        provider_name="Built-in",
    )
    assert isinstance(provider, OnyxWebCrawlerClient)
@@ -1,228 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest

from onyx.auth import users as users_module


def test_extract_email_requires_valid_format() -> None:
    """Helper should validate email format before returning value."""
    assert users_module._extract_email_from_jwt({"email": "invalid@"}) is None
    result = users_module._extract_email_from_jwt(
        {"preferred_username": "ValidUser@Example.COM"}
    )
    assert result == "validuser@example.com"


@pytest.mark.asyncio
async def test_get_or_create_user_updates_expiry(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Existing web-login users should be returned and their expiry synced."""
    monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
    invited_checked: dict[str, str] = {}

    def mark_invited(value: str) -> None:
        invited_checked["email"] = value

    domain_checked: dict[str, str] = {}

    def mark_domain(value: str) -> None:
        domain_checked["email"] = value

    monkeypatch.setattr(users_module, "verify_email_is_invited", mark_invited)
    monkeypatch.setattr(users_module, "verify_email_domain", mark_domain)
    email = "jwt-user@example.com"
    exp_value = 1_700_000_000
    payload: dict[str, Any] = {"email": email, "exp": exp_value}

    existing_user = MagicMock()
    existing_user.email = email
    existing_user.oidc_expiry = None
    existing_user.role.is_web_login.return_value = True  # type: ignore[attr-defined]

    manager_holder: dict[str, Any] = {}

    class StubUserManager:
        def __init__(self, _user_db: object) -> None:
            manager_holder["instance"] = self  # type: ignore[assignment]
            self.user_db = MagicMock()
            self.user_db.update = AsyncMock()

        async def get_by_email(self, email_arg: str) -> MagicMock:
            assert email_arg == email
            return existing_user

    monkeypatch.setattr(users_module, "UserManager", StubUserManager)
    monkeypatch.setattr(
        users_module,
        "SQLAlchemyUserAdminDB",
        lambda *args, **kwargs: MagicMock(),
    )

    result = await users_module._get_or_create_user_from_jwt(
        payload, MagicMock(), MagicMock()
    )

    assert result is existing_user
    assert invited_checked["email"] == email
    assert domain_checked["email"] == email
    expected_expiry = datetime.fromtimestamp(exp_value, tz=timezone.utc)
    instance = manager_holder["instance"]
    instance.user_db.update.assert_awaited_once_with(  # type: ignore[attr-defined]
        existing_user, {"oidc_expiry": expected_expiry}
    )
    assert existing_user.oidc_expiry == expected_expiry


@pytest.mark.asyncio
async def test_get_or_create_user_skips_inactive(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Inactive users should not be re-authenticated via JWT."""
    monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
    monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
    monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)

    email = "inactive@example.com"
    payload: dict[str, Any] = {"email": email}

    existing_user = MagicMock()
    existing_user.email = email
    existing_user.is_active = False
    existing_user.role.is_web_login.return_value = True  # type: ignore[attr-defined]

    class StubUserManager:
        def __init__(self, _user_db: object) -> None:
            self.user_db = MagicMock()
            self.user_db.update = AsyncMock()

        async def get_by_email(self, email_arg: str) -> MagicMock:
            assert email_arg == email
            return existing_user

    monkeypatch.setattr(users_module, "UserManager", StubUserManager)
    monkeypatch.setattr(
        users_module,
        "SQLAlchemyUserAdminDB",
        lambda *args, **kwargs: MagicMock(),
    )

    result = await users_module._get_or_create_user_from_jwt(
        payload, MagicMock(), MagicMock()
    )

    assert result is None


@pytest.mark.asyncio
async def test_get_or_create_user_handles_race_conditions(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """If provisioning races, newly inactive users should still be blocked."""
    monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
    monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
    monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)

    email = "race@example.com"
    payload: dict[str, Any] = {"email": email}

    inactive_user = MagicMock()
    inactive_user.email = email
    inactive_user.is_active = False
    inactive_user.role.is_web_login.return_value = True  # type: ignore[attr-defined]

    class StubUserManager:
        def __init__(self, _user_db: object) -> None:
            self.user_db = MagicMock()
            self.user_db.update = AsyncMock()
            self.get_calls = 0

        async def get_by_email(self, email_arg: str) -> MagicMock:
            assert email_arg == email
            if self.get_calls == 0:
                self.get_calls += 1
                raise users_module.exceptions.UserNotExists()
            self.get_calls += 1
            return inactive_user

        async def create(self, *args, **kwargs):  # type: ignore[no-untyped-def]
            raise users_module.exceptions.UserAlreadyExists()

    monkeypatch.setattr(users_module, "UserManager", StubUserManager)
    monkeypatch.setattr(
        users_module,
        "SQLAlchemyUserAdminDB",
        lambda *args, **kwargs: MagicMock(),
    )

    result = await users_module._get_or_create_user_from_jwt(
        payload, MagicMock(), MagicMock()
    )

    assert result is None


@pytest.mark.asyncio
async def test_get_or_create_user_provisions_new_user(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A brand new JWT user should be provisioned automatically."""
    email = "new-user@example.com"
    payload = {"email": email}
    created_user = MagicMock()
    created_user.email = email
    created_user.oidc_expiry = None
    created_user.role.is_web_login.return_value = True  # type: ignore[attr-defined]

    monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", False)
    monkeypatch.setattr(users_module, "generate_password", lambda: "TempPass123!")
    monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
    monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)

    recorded: dict[str, Any] = {}

    class StubUserManager:
        def __init__(self, _user_db: object) -> None:
            recorded["instance"] = self
            self.user_db = MagicMock()
            self.user_db.update = AsyncMock()

        async def get_by_email(self, _email: str) -> MagicMock:
            raise users_module.exceptions.UserNotExists()

        async def create(self, user_create, safe=False, request=None):  # type: ignore[no-untyped-def]
            recorded["user_create"] = user_create
            recorded["request"] = request
            return created_user

    monkeypatch.setattr(users_module, "UserManager", StubUserManager)
    monkeypatch.setattr(
        users_module,
        "SQLAlchemyUserAdminDB",
        lambda *args, **kwargs: MagicMock(),
    )

    request = MagicMock()
    result = await users_module._get_or_create_user_from_jwt(
        payload, request, MagicMock()
    )

    assert result is created_user
    created_payload = recorded["user_create"]
    assert created_payload.email == email
    assert created_payload.is_verified is True
    assert recorded["request"] is request


@pytest.mark.asyncio
async def test_get_or_create_user_requires_email_claim() -> None:
    """Tokens without a usable email claim should be ignored."""
    result = await users_module._get_or_create_user_from_jwt(
        {}, MagicMock(), MagicMock()
    )
    assert result is None
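The first removed test above implies a specific claim-extraction contract: check the "email" claim, fall back to "preferred_username", reject malformed values, and normalize to lowercase. A minimal sketch of that behavior; the regex and function name here are illustrative, not the actual code in onyx.auth.users:

import re

_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$")

def extract_email(claims: dict) -> str | None:
    # Prefer "email", fall back to "preferred_username"; validate the format
    # before returning, and lowercase the result for stable lookups.
    for key in ("email", "preferred_username"):
        value = claims.get(key)
        if isinstance(value, str) and _EMAIL_RE.match(value):
            return value.lower()
    return None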
@@ -6,7 +6,6 @@ from typing import Any
import pytest

from onyx.tools.models import ToolResponse
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
@@ -79,14 +78,6 @@ class FakeDummyTool(CustomTool):
        ),
    )

    def run_v2(
        self,
        run_context: RunContextWrapper[Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        return "Tool executed successfully"


@pytest.fixture
def fake_dummy_tool() -> FakeDummyTool:
@@ -99,8 +90,8 @@ __all__ = [
    "chat_turn_dependencies",
    "fake_db_session",
    "fake_dummy_tool",
    "fake_model",
    "fake_llm",
    "fake_model",
    "fake_redis_client",
    "fake_tools",
]
@@ -298,5 +298,4 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
        timeout=30,
        parallel_tool_calls=False,
        mock_response=MOCK_LLM_RESPONSE,
        stream_options={"include_usage": True},
    )
@@ -37,7 +37,7 @@ def test_partial_match_in_model_map() -> None:
    "supports_audio_output": False,
    "supports_function_calling": True,
    "supports_response_schema": True,
    "supports_system_messages": False,
    "supports_system_messages": True,
    "supports_tool_choice": True,
    "supports_vision": True,
}
@@ -46,13 +46,13 @@ def test_partial_match_in_model_map() -> None:
    assert result1 is not None
    for key, value in _EXPECTED_FIELDS.items():
        assert key in result1
        assert result1[key] == value, "Unexpected value for key: {}".format(key)
        assert result1[key] == value

    result2 = find_model_obj(model_map, "openai", "gemma-3-27b-it")
    assert result2 is not None
    for key, value in _EXPECTED_FIELDS.items():
        assert key in result2
        assert result2[key] == value, "Unexpected value for key: {}".format(key)
        assert result2[key] == value

    get_model_map.cache_clear()
@@ -1,48 +0,0 @@
"""Unit tests for tracing setup functions."""

import importlib
import os

from onyx.configs import app_configs


def test_setup_langfuse_if_creds_available_with_creds() -> None:
    """Test that setup_langfuse_if_creds_available executes without error when credentials are available."""
    # Set credentials to non-empty values to avoid early return
    os.environ["LANGFUSE_SECRET_KEY"] = "test-secret-key"
    os.environ["LANGFUSE_PUBLIC_KEY"] = "test-public-key"

    # Reload modules to pick up new environment variables
    importlib.reload(app_configs)
    from onyx.tracing import langfuse_tracing

    importlib.reload(langfuse_tracing)

    # Call the function - should not raise an error
    langfuse_tracing.setup_langfuse_if_creds_available()

    # Clean up
    os.environ.pop("LANGFUSE_SECRET_KEY", None)
    os.environ.pop("LANGFUSE_PUBLIC_KEY", None)
    importlib.reload(app_configs)


def test_setup_braintrust_if_creds_available_with_creds() -> None:
    """Test that setup_braintrust_if_creds_available executes without error when credentials are available."""
    # Set credentials to non-empty values to avoid early return
    os.environ["BRAINTRUST_API_KEY"] = "test-api-key"
    os.environ["BRAINTRUST_PROJECT"] = "test-project"

    # Reload modules to pick up new environment variables
    importlib.reload(app_configs)
    from onyx.tracing import braintrust_tracing

    importlib.reload(braintrust_tracing)

    # Call the function - should not raise an error
    braintrust_tracing.setup_braintrust_if_creds_available()

    # Clean up environment variables
    os.environ.pop("BRAINTRUST_API_KEY", None)
    os.environ.pop("BRAINTRUST_PROJECT", None)
    importlib.reload(app_configs)
@@ -109,7 +109,7 @@ Resources:
      Family: !Sub ${Environment}-${ServiceName}-TaskDefinition
      ContainerDefinitions:
        - Name: nginx
          Image: nginx:1.25.5-alpine
          Image: nginx:1.23.4-alpine
          Cpu: 0
          PortMappings:
            - Name: nginx-80-tcp
@@ -55,6 +55,8 @@ services:
      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
      - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
      - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
      - BING_API_KEY=${BING_API_KEY:-}
      - EXA_API_KEY=${EXA_API_KEY:-}
      - DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
@@ -168,6 +170,8 @@ services:
      - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
      - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      - BING_API_KEY=${BING_API_KEY:-}
      - EXA_API_KEY=${EXA_API_KEY:-}
      # Query Options
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}
@@ -400,7 +404,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -188,7 +188,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -212,7 +212,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -224,7 +224,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -193,7 +193,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -3,7 +3,7 @@
# =============================================================================
# This is the default configuration for Onyx. This file is fairly configurable,
# also see env.template for possible settings.
#
#
# PRODUCTION DEPLOYMENT CHECKLIST:
# To convert this setup to a production deployment following best practices,
# follow the checklist below. Note that there are other ways to secure the Onyx
@@ -283,7 +283,7 @@ services:
        max-file: "6"

  nginx:
    image: nginx:1.25.5-alpine
    image: nginx:1.23.4-alpine
    restart: unless-stopped
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
@@ -300,7 +300,7 @@ services:
      - NGINX_PROXY_SEND_TIMEOUT=${NGINX_PROXY_SEND_TIMEOUT:-300}
      - NGINX_PROXY_READ_TIMEOUT=${NGINX_PROXY_READ_TIMEOUT:-300}
    ports:
      - "${HOST_PORT_80:-80}:80"
      - "80:80"
      - "${HOST_PORT:-3000}:80"  # allow for localhost:3000 usage, since that is the norm
    volumes:
      - ../data/nginx:/etc/nginx/conf.d
@@ -319,7 +319,7 @@ services:
      # in order to make this work on both Unix-like systems and windows
      # PRODUCTION: Change to app.conf.template.prod for production nginx config
    command: >
      /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
      /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
      && /etc/nginx/conf.d/run-nginx.sh app.conf.template"

  cache: