Compare commits

54 Commits

Author SHA1 Message Date
Jessica Singh
d272ac252e mypy 2025-11-21 17:08:18 -08:00
Jessica Singh
4620fb1129 fix 2025-11-21 14:08:09 -08:00
Jessica Singh
7e334f0de1 use tokenizer 2025-11-21 14:04:07 -08:00
Jessica Singh
a381104d6f thread context for slack bot 2025-11-21 14:04:07 -08:00
Raunak Bhagat
93d2febf2a fix: Update buttons and stylings for new-team-modal (#6384) 2025-11-21 21:26:51 +00:00
Raunak Bhagat
693286411a feat: Responsiveness (#6383) 2025-11-21 21:01:27 +00:00
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Justin Tahara
01a3064ca3 fix(testrail): Linting (#6382) 2025-11-21 10:50:08 -08:00
sashank-rayapudi-ai
09a80265ee feat(testrail): Implement a read-only custom connector for Testrail (#6084) 2025-11-21 10:16:40 -08:00
Wenxi
2a77481c1e test(onboarding): add playwright test for onboarding flow (#6376) 2025-11-21 12:23:37 -05:00
Jamison Lahman
6838487689 chore(deployments): separate flag for model-server, enable nightly (#6377) 2025-11-21 04:29:41 +00:00
Jamison Lahman
1713c24080 chore(docker): breakup model-server model layers (#6370) 2025-11-21 03:47:47 +00:00
Chris Weaver
73b3a2525a fix: chat switching (#6374) 2025-11-20 18:32:54 -08:00
Wenxi
59738d9243 feat: cross link cookies (#6371) 2025-11-21 02:03:52 +00:00
Wenxi
c0ff9c623b feat(APIs): web search apis and indexed sources api (#6363) 2025-11-20 20:23:06 -05:00
Jessica Singh
c03979209a fix(ui): icon alignment + color (#6373) 2025-11-20 17:16:10 -08:00
Justin Tahara
a0b7639693 fix(connectors): Normalizing Onyx Metatada Connector Type (#6315) 2025-11-21 00:46:45 +00:00
Raunak Bhagat
e3ede3c186 fix: Sidebar fixes (#6358) 2025-11-21 00:35:31 +00:00
Jessica Singh
092dbebdf2 fix(migration): exa env var into db (#6366) 2025-11-21 00:12:09 +00:00
Justin Tahara
838e2fe924 chore(bedrock): Add better logging (#6368) 2025-11-20 23:38:19 +00:00
Chris Weaver
48e2bfa3eb chore: prevent sentry spam on fake issue (#6369) 2025-11-20 22:47:30 +00:00
Jamison Lahman
2a004ad257 chore(deployments): fix nightly tagging + add alerts & workflow_dispatch (#6367) 2025-11-20 21:55:24 +00:00
Wenxi
416c7fd75e chore(WebSearch): remove old web search env vars and update tooltip (#6365) 2025-11-20 21:09:24 +00:00
  Co-authored-by: justin-tahara <justintahara@gmail.com>
Justin Tahara
a4372b461f feat(helm): Add Tolerations and Affinity (#6362) 2025-11-20 20:25:20 +00:00
mristau-alltrails
7eb13db6d9 SECURITY FIX: CVE-2023-38545 and CVE-2023-38546 (#6356) 2025-11-20 20:11:35 +00:00
Justin Tahara
c0075d5f59 fix(docprocessing): Pause Failing Connectors (#6350) 2025-11-20 19:14:56 +00:00
Wenxi
475a3afe56 fix(connector): handle hubspot ticket with None content (#6357) 2025-11-20 13:35:46 -05:00
SubashMohan
bf5b8e7bae fix(Project): project pending issues (#6099) 2025-11-20 17:53:08 +00:00
Jamison Lahman
4ff28c897b chore(dev): nginx container port 80 respects HOST_PORT_80 (#6338) 2025-11-20 17:48:10 +00:00
SubashMohan
ec9e9be42e Fix/user file modal (#6333) 2025-11-20 16:41:38 +00:00
Nikolas Garza
af5fa8fe54 fix: web search and image generation tool playwright test failures (#6347) 2025-11-20 07:13:05 +00:00
Jamison Lahman
03a9e9e068 chore(gha): playwright browser cache is arch-aware (#6351) 2025-11-20 03:28:53 +00:00
Richard Guan
ad81c3f9eb chore(tracing): updates (#6322) 2025-11-20 00:58:00 +00:00
Jamison Lahman
62129f4ab9 chore(gha): require playwright passing on merge (#6346) 2025-11-20 00:55:19 +00:00
Jamison Lahman
b30d38c747 chore(gha): fix zizmor issues (#6344) 2025-11-19 23:57:34 +00:00
Nikolas Garza
0596b57501 fix: featured assistant typo (#6341) 2025-11-19 14:44:54 -08:00
Jamison Lahman
482b2c4204 chore(gha): run uvx zizmor --fix=all (#6342) 2025-11-19 14:26:45 -08:00
Jamison Lahman
df155835b1 chore(docker): docker bake UX (#6339) 2025-11-19 14:19:53 -08:00
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Richard Guan
fd0762a1ee chore(agent): framework query improvements (#6297) 2025-11-19 21:43:33 +00:00
Jamison Lahman
bd41618dd9 chore(deployments): correctly set --debug for docker build (#6337) 2025-11-19 11:04:15 -08:00
Justin Tahara
5a7c6312af feat(jwt): JIT provision from token (#6252) 2025-11-19 10:06:20 -08:00
Raunak Bhagat
a477508bd7 fix: Fix header flashing (#6331) 2025-11-19 09:27:49 -08:00
  Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Raunak Bhagat
8ac34a8433 refactor: input type in fixes (#6335) 2025-11-19 08:31:39 -08:00
Raunak Bhagat
2c51466bc3 fix: Some minor touch-ups for the new modal (#6332) 2025-11-19 14:03:15 +00:00
Raunak Bhagat
62966bd172 fix: Switch fix (#6279) 2025-11-19 01:40:40 -08:00
Jamison Lahman
a8d4482b59 chore(deployments): set provenance=false and flag debug (#6330) 2025-11-18 22:26:53 -08:00
Jamison Lahman
dd42a45008 chore(deployments): flag to disable docker caching (#6328) 2025-11-19 04:07:07 +00:00
Jessica Singh
a368556282 feat(web search providers): adding support and changing env var approach (#6273) 2025-11-19 02:49:54 +00:00
Evan Lohn
679d1a5ef6 fix: openpyxl bug (#6317) 2025-11-19 00:59:46 +00:00
Nikolas Garza
12e49cd661 fix: slack config forms + scope issues (#6318) 2025-11-18 16:49:16 -08:00
Jamison Lahman
1859a0ad79 chore(gha): run zizmor (#6326) 2025-11-18 16:10:07 -08:00
Jamison Lahman
9199d146be fix(tests): test_partial_match_in_model_map AssertionError (#6321) 2025-11-18 16:06:01 -08:00
Jamison Lahman
9c1208ffd6 chore(deployments): separate builds by platform (#6314) 2025-11-18 14:49:23 -08:00
Jamison Lahman
c3387e33eb chore(deployments): remove DEPLOYMENT from cache path (#6319) 2025-11-18 14:16:09 -08:00
Jamison Lahman
c37f633a37 chore(deployments): remove driver-opts from model-server build (#6313) 2025-11-18 10:45:24 -08:00
235 changed files with 12423 additions and 1953 deletions

View File

@@ -7,9 +7,9 @@ runs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
restore-keys: |
${{ runner.os }}-playwright-
${{ runner.os }}-${{ runner.arch }}-playwright-
- name: Install playwright
shell: bash
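The Playwright browser cache key now includes runner.arch next to runner.os, so x64 and arm64 runners stop sharing (and overwriting) each other's ~/.cache/ms-playwright entries. A minimal sketch of the same pattern with the generic cache action (the action pin and step name here are illustrative, not taken from the diff):

    - name: Cache Playwright browsers
      uses: actions/cache@v4
      with:
        path: ~/.cache/ms-playwright
        key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }}
        restore-keys: |
          ${{ runner.os }}-${{ runner.arch }}-playwright-

The restore-keys prefix still lets a runner fall back to its most recent cache for the same OS/arch pair when the requirements hash changes.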

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
@@ -17,6 +20,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
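Most workflows in this range get the same zizmor-driven hardening: the workflow-level permissions block grants the default GITHUB_TOKEN only contents: read, and actions/checkout is told not to persist its credentials into the local git config, so later steps cannot reuse the token implicitly. A minimal sketch of the combined pattern (job and step names are illustrative):

    permissions:
      contents: read

    jobs:
      example:
        runs-on: ubuntu-latest
        steps:
          - name: Checkout code
            uses: actions/checkout@v4
            with:
              # Do not leave the GITHUB_TOKEN in .git/config for later steps.
              persist-credentials: false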

View File

@@ -6,6 +6,9 @@ on:
- "*"
workflow_dispatch:
permissions:
contents: read
env:
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -30,7 +33,7 @@ jobs:
- name: Check which components to build and version info
id: check
run: |
TAG="${{ github.ref_name }}"
TAG="${GITHUB_REF_NAME}"
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
IS_CLOUD=false
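Two things change in the determine-builds script: the ref name is read from the predefined GITHUB_REF_NAME environment variable instead of being spliced into the script with ${{ github.ref_name }} (the template-injection pattern zizmor flags), and slashes are replaced with hyphens because Docker tags cannot contain '/'. A standalone sketch of the sanitization, runnable in any shell (the ref value is made up):

    # Simulate the workflow's tag sanitization for a release-style ref.
    GITHUB_REF_NAME="release/v2.3.4"
    TAG="${GITHUB_REF_NAME}"
    SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
    echo "$SANITIZED_TAG"   # prints: release-v2.3.4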
@@ -79,22 +82,143 @@ jobs:
echo "sanitized-tag=$SANITIZED_TAG"
} >> "$GITHUB_OUTPUT"
build-web:
build-web-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-build
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
DEPLOYMENT: standalone
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-web-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-web:
needs:
- determine-builds
- build-web-amd64
- build-web-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -109,50 +233,37 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-${{ env.DEPLOYMENT }}-cache,mode=max
build-web-cloud:
build-web-cloud-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web-cloud == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-cloud-build
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
DEPLOYMENT: cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
@@ -161,8 +272,6 @@ jobs:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -173,14 +282,13 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
@@ -195,27 +303,259 @@ jobs:
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-${{ env.DEPLOYMENT }}-cache,mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend:
build-web-cloud-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web-cloud == 'true'
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NEXT_PUBLIC_CLOUD_ENABLED=true
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-web-cloud:
needs:
- determine-builds
- build-web-cloud-amd64
- build-web-cloud-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
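Each per-architecture job pushes its image by digest only (push-by-digest=true, no tags), and the merge job then assembles one multi-arch tag from the two digests with docker buildx imagetools create. Expanded by hand, the run step above amounts to roughly the following; the tag and digests are placeholders, not values from this workflow run:

    docker buildx imagetools create \
      -t onyxdotapp/onyx-web-server-cloud:v1.2.3 \
      onyxdotapp/onyx-web-server-cloud@sha256:<amd64-digest> \
      onyxdotapp/onyx-web-server-cloud@sha256:<arm64-digest>

    # Confirm the resulting tag is a manifest list covering both platforms.
    docker buildx imagetools inspect onyxdotapp/onyx-web-server-cloud:v1.2.3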
build-backend-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-build
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
merge-backend:
needs:
- determine-builds
- build-backend-amd64
- build-backend-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -230,6 +570,159 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
build-model-server-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-model-server-amd64
- volume=40gb
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
provenance: false
sbom: false
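The model-server builds additionally set provenance: false and sbom: false. With provenance attestations enabled, buildx typically pushes an image index (image manifest plus attestation manifest) rather than a bare single-platform manifest, which would complicate the digest-based merge step later on; disabling both keeps each per-arch push a plain manifest. One way to check what a given digest refers to (the digest is a placeholder):

    docker buildx imagetools inspect --raw \
      onyxdotapp/onyx-model-server@sha256:<amd64-digest> | jq -r .mediaType
    # ...image.manifest.v1+json -> single-platform manifest
    # ...image.index.v1+json    -> index (e.g. image + attestations)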
build-model-server-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-model-server-arm64
- volume=40gb
- extras=ecr-cache
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
provenance: false
sbom: false
merge-model-server:
needs:
- determine-builds
- build-model-server-amd64
- build-model-server-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -239,44 +732,6 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ env.DEPLOYMENT }}-cache,mode=max
build-model-server:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-model-server-build
- volume=40gb
- extras=ecr-cache
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
DOCKER_BUILDKIT: 1
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
@@ -290,43 +745,26 @@ jobs:
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ env.DEPLOYMENT }}-cache,mode=max
- name: Create and push manifest
env:
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES
trivy-scan-web:
needs: [determine-builds, build-web]
if: needs.build-web.result == 'success'
needs:
- determine-builds
- merge-web
if: needs.merge-web.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
env:
@@ -359,11 +797,13 @@ jobs:
${SCAN_IMAGE}
trivy-scan-web-cloud:
needs: [determine-builds, build-web-cloud]
if: needs.build-web-cloud.result == 'success'
needs:
- determine-builds
- merge-web-cloud
if: needs.merge-web-cloud.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
env:
@@ -396,11 +836,13 @@ jobs:
${SCAN_IMAGE}
trivy-scan-backend:
needs: [determine-builds, build-backend]
if: needs.build-backend.result == 'success'
needs:
- determine-builds
- merge-backend
if: needs.merge-backend.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
env:
@@ -410,6 +852,8 @@ jobs:
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
@@ -438,11 +882,13 @@ jobs:
${SCAN_IMAGE}
trivy-scan-model-server:
needs: [determine-builds, build-model-server]
if: needs.build-model-server.result == 'success'
needs:
- determine-builds
- merge-model-server
if: needs.merge-model-server.result == 'success'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
env:
@@ -475,33 +921,85 @@ jobs:
${SCAN_IMAGE}
notify-slack-on-failure:
needs: [build-web, build-web-cloud, build-backend, build-model-server]
if: always() && (needs.build-web.result == 'failure' || needs.build-web-cloud.result == 'failure' || needs.build-backend.result == 'failure' || needs.build-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
needs:
- build-web-amd64
- build-web-arm64
- merge-web
- build-web-cloud-amd64
- build-web-cloud-arm64
- merge-web-cloud
- build-backend-amd64
- build-backend-arm64
- merge-backend
- build-model-server-amd64
- build-model-server-arm64
- merge-model-server
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
steps:
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Determine failed jobs
id: failed-jobs
shell: bash
run: |
FAILED_JOBS=""
if [ "${{ needs.build-web.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web\\n"
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
fi
if [ "${{ needs.build-web-cloud.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud\\n"
if [ "${NEEDS_BUILD_WEB_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-arm64\\n"
fi
if [ "${{ needs.build-backend.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend\\n"
if [ "${NEEDS_MERGE_WEB_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-web\\n"
fi
if [ "${{ needs.build-model-server.result }}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server\\n"
if [ "${NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-amd64\\n"
fi
if [ "${NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-arm64\\n"
fi
if [ "${NEEDS_MERGE_WEB_CLOUD_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-web-cloud\\n"
fi
if [ "${NEEDS_BUILD_BACKEND_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend-amd64\\n"
fi
if [ "${NEEDS_BUILD_BACKEND_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-backend-arm64\\n"
fi
if [ "${NEEDS_MERGE_BACKEND_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-backend\\n"
fi
if [ "${NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server-amd64\\n"
fi
if [ "${NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-model-server-arm64\\n"
fi
if [ "${NEEDS_MERGE_MODEL_SERVER_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• merge-model-server\\n"
fi
# Remove trailing \n and set output
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
env:
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT: ${{ needs.build-web-cloud-amd64.result }}
NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT: ${{ needs.build-web-cloud-arm64.result }}
NEEDS_MERGE_WEB_CLOUD_RESULT: ${{ needs.merge-web-cloud.result }}
NEEDS_BUILD_BACKEND_AMD64_RESULT: ${{ needs.build-backend-amd64.result }}
NEEDS_BUILD_BACKEND_ARM64_RESULT: ${{ needs.build-backend-arm64.result }}
NEEDS_MERGE_BACKEND_RESULT: ${{ needs.merge-backend.result }}
NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT: ${{ needs.build-model-server-amd64.result }}
NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT: ${{ needs.build-model-server-arm64.result }}
NEEDS_MERGE_MODEL_SERVER_RESULT: ${{ needs.merge-model-server.result }}
- name: Send Slack notification
uses: ./.github/actions/slack-notify
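The failure-notification job now reads every needs.*.result through the env: mapping above instead of interpolating ${{ }} expressions directly into the bash script, matching the injection hardening applied elsewhere in this range. Assuming those same env names, the repeated if blocks could also be expressed as a loop; a compact bash sketch, not the workflow's actual code:

    FAILED_JOBS=""
    for job in build-web-amd64 build-web-arm64 merge-web \
               build-web-cloud-amd64 build-web-cloud-arm64 merge-web-cloud \
               build-backend-amd64 build-backend-arm64 merge-backend \
               build-model-server-amd64 build-model-server-arm64 merge-model-server; do
      # e.g. build-web-amd64 -> NEEDS_BUILD_WEB_AMD64_RESULT
      var="NEEDS_$(echo "$job" | tr 'a-z-' 'A-Z_')_RESULT"
      if [ "${!var}" == "failure" ]; then
        FAILED_JOBS="${FAILED_JOBS}• ${job}\\n"
      fi
    done
    FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
    echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"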

View File

@@ -10,6 +10,9 @@ on:
description: "The version (ie v1.0.0-beta.0) to tag as beta"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -29,13 +32,19 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}
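The retag jobs likewise move the version input into a VERSION env var before it reaches the shell. docker buildx imagetools create -t <new-tag> <existing-ref> re-tags an image that already exists in the registry without pulling it locally, and it keeps multi-arch manifest lists intact. A usage example with a made-up version:

    VERSION=v1.0.0-beta.7   # hypothetical workflow_dispatch input
    docker buildx imagetools create \
      -t onyxdotapp/onyx-web-server:beta \
      "onyxdotapp/onyx-web-server:${VERSION}"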

View File

@@ -10,6 +10,9 @@ on:
description: "The version (ie v0.0.1) to tag as latest"
required: true
permissions:
contents: read
jobs:
tag:
# See https://runs-on.com/runners/linux/
@@ -29,13 +32,19 @@ jobs:
run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV
- name: Pull, Tag and Push Web Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}
- name: Pull, Tag and Push API Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}
- name: Pull, Tag and Push Model Server Image
env:
VERSION: ${{ github.event.inputs.version }}
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}

View File

@@ -17,6 +17,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install Helm CLI
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4

View File

@@ -15,16 +15,21 @@ on:
permissions:
actions: read
contents: read
security-events: write
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
@@ -54,7 +59,9 @@ jobs:
- name: Print report
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web

View File

@@ -8,6 +8,9 @@ on:
pull_request:
branches: [main]
permissions:
contents: read
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
@@ -37,6 +40,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -67,6 +72,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -97,10 +104,12 @@ jobs:
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
env:
TEST_DIR: ${{ matrix.test-dir }}
run: |
py.test \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/external_dependency_unit/${{ matrix.test-dir }}
backend/tests/external_dependency_unit/${TEST_DIR}

View File

@@ -9,6 +9,9 @@ on:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
permissions:
contents: read
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
@@ -20,6 +23,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Set up Helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
@@ -32,9 +36,11 @@ jobs:
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "default_branch: ${DEFAULT_BRANCH}"
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"

View File

@@ -10,6 +10,9 @@ on:
- main
- "release/**"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -37,6 +40,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -65,6 +70,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -85,8 +92,11 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
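The integration backend build now layers two cache sources: the shared backend-cache tag in the ECR cache registry and, as a fallback, the published onyxdotapp/onyx-backend:latest image (presumably usable because the release builds also push inline cache metadata). The equivalent plain buildx invocation would look roughly like this, with the target tag as a placeholder:

    docker buildx build ./backend -f ./backend/Dockerfile \
      --cache-from "type=registry,ref=$RUNS_ON_ECR_CACHE:backend-cache" \
      --cache-from "type=registry,ref=onyxdotapp/onyx-backend:latest" \
      --cache-to "type=registry,ref=$RUNS_ON_ECR_CACHE:backend-cache,mode=max" \
      --push -t "$RUNS_ON_ECR_CACHE:integration-test-backend-test-$RUN_ID"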
@@ -96,6 +106,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -116,8 +128,10 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
@@ -126,6 +140,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -141,9 +157,16 @@ jobs:
- name: Build and push integration test image with Docker Bake
env:
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: cd backend && docker buildx bake --push integration
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
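Instead of hard-coding cache settings in docker-bake.hcl, the workflow now overrides the bake targets' cache-from/cache-to attributes on the command line with --set, pointing both the backend and integration targets at shared cache tags. A quick way to sanity-check how such overrides resolve, without building anything, is bake's --print mode (the TAG value below is a placeholder):

    cd backend && \
    REPOSITORY="$RUNS_ON_ECR_CACHE" \
    INTEGRATION_REPOSITORY="$RUNS_ON_ECR_CACHE" \
    TAG=integration-test-local \
    docker buildx bake --print \
      --set "backend.cache-from=type=registry,ref=$RUNS_ON_ECR_CACHE:backend-cache" \
      integration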
integration-tests:
needs:
@@ -168,6 +191,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -181,6 +206,9 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -189,8 +217,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
@@ -267,6 +295,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
@@ -322,6 +351,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -330,6 +361,9 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers for multi-tenant tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -337,8 +371,8 @@ jobs:
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
@@ -379,6 +413,9 @@ jobs:
echo "Finished waiting for service."
- name: Run Multi-Tenant Integration Tests
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx_default \
@@ -394,6 +431,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
@@ -402,7 +440,7 @@ jobs:
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e DEV_MODE=true \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
${ECR_CACHE}:integration-test-${RUN_ID} \
/app/tests/integration/multitenant_tests
- name: Dump API server logs (multi-tenant)
@@ -436,13 +474,6 @@ jobs:
needs: [integration-tests, multitenant-tests]
if: ${{ always() }}
steps:
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -5,6 +5,9 @@ concurrency:
on: push
permissions:
contents: read
jobs:
jest-tests:
name: Jest Tests
@@ -12,6 +15,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4

View File

@@ -1,7 +1,7 @@
name: PR Labeler
on:
pull_request_target:
pull_request:
branches:
- main
types:
@@ -12,7 +12,6 @@ on:
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:

View File

@@ -7,6 +7,9 @@ on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
linear-check:
runs-on: ubuntu-latest

View File

@@ -7,6 +7,9 @@ on:
merge_group:
types: [checks_requested]
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -33,6 +36,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
@@ -60,6 +65,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -80,8 +87,10 @@ jobs:
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -90,6 +99,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -110,8 +121,10 @@ jobs:
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
@@ -119,6 +132,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -134,9 +149,16 @@ jobs:
- name: Build and push integration test image with Docker Bake
env:
REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
run: cd backend && docker buildx bake --push integration
run: |
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
@@ -161,6 +183,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -174,6 +198,9 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
@@ -181,8 +208,8 @@ jobs:
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} \
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
@@ -307,13 +334,6 @@ jobs:
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # ratchet:actions/github-script@v8
with:
script: |
const needs = ${{ toJSON(needs) }};
const failed = Object.values(needs).some(n => n.result !== 'success');
if (failed) {
core.setFailed('One or more upstream jobs failed or were cancelled.');
} else {
core.notice('All required jobs succeeded.');
}
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -5,6 +5,9 @@ concurrency:
on: push
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -42,6 +45,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -62,8 +67,10 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
type=registry,ref=onyxdotapp/onyx-web-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-image:
@@ -73,6 +80,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -93,8 +102,11 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
@@ -104,6 +116,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
@@ -124,8 +138,10 @@ jobs:
platforms: linux/arm64
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
push: true
cache-from: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-cache,mode=max
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
playwright-tests:
@@ -143,6 +159,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
@@ -168,18 +185,26 @@ jobs:
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
EXA_API_KEY=${{ env.EXA_API_KEY }}
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
ONYX_MODEL_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
ONYX_WEB_SERVER_IMAGE=${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
if [ "${{ matrix.project }}" = "no-auth" ]; then
echo "PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS=true" >> deployment/docker_compose/.env
fi
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -255,10 +280,15 @@ jobs:
- name: Run Playwright tests
working-directory: ./web
env:
PROJECT: ${{ matrix.project }}
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test --project ${{ matrix.project }}
if [ "${PROJECT}" = "no-auth" ]; then
export PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS=true
fi
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
@@ -271,10 +301,12 @@ jobs:
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
@@ -283,6 +315,16 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
needs: [playwright-tests]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
@@ -21,6 +24,8 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -11,6 +11,9 @@ on:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
permissions:
contents: read
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
@@ -132,6 +135,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
@@ -214,8 +219,10 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
--data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
$SLACK_WEBHOOK

View File

@@ -11,6 +11,9 @@ on:
required: false
default: 'main'
permissions:
contents: read
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -36,6 +39,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
@@ -122,10 +127,12 @@ jobs:
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)

View File

@@ -10,6 +10,9 @@ on:
- main
- 'release/**'
permissions:
contents: read
jobs:
backend-check:
# See https://runs-on.com/runners/linux/
@@ -28,6 +31,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies

View File

@@ -7,6 +7,9 @@ on:
merge_group:
pull_request: null
permissions:
contents: read
jobs:
quality-checks:
# See https://runs-on.com/runners/linux/
@@ -16,6 +19,7 @@ jobs:
- uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
with:
python-version: "3.11"

View File

@@ -16,6 +16,7 @@ jobs:
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install git-filter-repo
run: |

View File

@@ -3,30 +3,29 @@ name: Nightly Tag Push
on:
schedule:
- cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC
workflow_dispatch:
permissions:
contents: write # Allows pushing tags to the repository
jobs:
create-and-push-tag:
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-create-and-push-tag"]
runs-on: ubuntu-slim
steps:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
# and not rkuo's personal account. It is fine to leave this key as is!
- name: Checkout code
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # ratchet:actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@onyx.app"
git config user.name "Onyx Bot [bot]"
git config user.email "onyx-bot[bot]@onyx.app"
- name: Check for existing nightly tag
id: check_tag
@@ -54,3 +53,12 @@ jobs:
run: |
TAG_NAME="nightly-latest-$(date +'%Y%m%d')"
git push origin $TAG_NAME
- name: Send Slack notification
if: failure()
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
title: "🚨 Nightly Tag Push Failed"
ref-name: ${{ github.ref_name }}
failed-jobs: "create-and-push-tag"

.github/workflows/zizmor.yml (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
name: Run Zizmor
on:
push:
branches: ["main"]
pull_request:
branches: ["**"]
permissions: {}
jobs:
zizmor:
name: zizmor
runs-on: ubuntu-slim
permissions:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # ratchet:actions/checkout@v5.0.1
with:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3
- name: Run zizmor
run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif
category: zizmor

View File

@@ -1,4 +1,42 @@
FROM python:3.11.7-slim-bookworm
# Base stage with dependencies
FROM python:3.11.7-slim-bookworm AS base
ENV DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/
RUN mkdir -p /app/.cache/huggingface
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
-r /tmp/requirements.txt && \
rm -rf ~/.cache/uv /tmp/*.txt
# Stage for downloading tokenizers
FROM base AS tokenizers
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1');"
# Stage for downloading Onyx models
FROM base AS onyx-models
RUN python -c "from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model');"
# Stage for downloading embedding and reranking models
FROM base AS embedding-models
RUN python -c "from huggingface_hub import snapshot_download; \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"
# Initialize SentenceTransformer to cache the custom architecture
RUN python -c "from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# Final stage - combine all downloads
FROM base AS final
LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is for the Onyx model server which runs all of the \
@@ -6,44 +44,19 @@ AI models for Onyx. This container and all the code is MIT Licensed and free for
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
visit https://github.com/onyx-dot-app/onyx."
ENV DANSWER_RUNNING_IN_DOCKER="true" \
HF_HOME=/app/.cache/huggingface
COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/
# Create non-root user for security best practices
RUN mkdir -p /app && \
groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
-r /tmp/requirements.txt && \
rm -rf ~/.cache/uv /tmp/*.txt
# Pre-downloading models for setups with limited egress
# Download tokenizers, distilbert for the Onyx model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);" && \
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
chown -R onyx:onyx /app
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
COPY --chown=onyx:onyx --from=tokenizers /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=onyx-models /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface
WORKDIR /app

View File

@@ -0,0 +1,89 @@
"""add internet search and content provider tables
Revision ID: 1f2a3b4c5d6e
Revises: 9drpiiw74ljy
Create Date: 2025-11-10 19:45:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1f2a3b4c5d6e"
down_revision = "9drpiiw74ljy"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"internet_search_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_search_provider_is_active",
"internet_search_provider",
["is_active"],
)
op.create_table(
"internet_content_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), nullable=False, unique=True),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column(
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
op.create_index(
"ix_internet_content_provider_is_active",
"internet_content_provider",
["is_active"],
)
def downgrade() -> None:
op.drop_index(
"ix_internet_content_provider_is_active", table_name="internet_content_provider"
)
op.drop_table("internet_content_provider")
op.drop_index(
"ix_internet_search_provider_is_active", table_name="internet_search_provider"
)
op.drop_table("internet_search_provider")

View File

@@ -0,0 +1,89 @@
"""seed_exa_provider_from_env
Revision ID: 3c9a65f1207f
Revises: 1f2a3b4c5d6e
Create Date: 2025-11-20 19:18:00.000000
"""
from __future__ import annotations
import os
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from dotenv import load_dotenv, find_dotenv
from onyx.utils.encryption import encrypt_string_to_bytes
revision = "3c9a65f1207f"
down_revision = "1f2a3b4c5d6e"
branch_labels = None
depends_on = None
EXA_PROVIDER_NAME = "Exa"
def _get_internet_search_table(metadata: sa.MetaData) -> sa.Table:
return sa.Table(
"internet_search_provider",
metadata,
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("name", sa.String),
sa.Column("provider_type", sa.String),
sa.Column("api_key", sa.LargeBinary),
sa.Column("config", postgresql.JSONB),
sa.Column("is_active", sa.Boolean),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
)
def upgrade() -> None:
load_dotenv(find_dotenv())
exa_api_key = os.environ.get("EXA_API_KEY")
if not exa_api_key:
return
bind = op.get_bind()
metadata = sa.MetaData()
table = _get_internet_search_table(metadata)
existing = bind.execute(
sa.select(table.c.id).where(table.c.name == EXA_PROVIDER_NAME)
).first()
if existing:
return
encrypted_key = encrypt_string_to_bytes(exa_api_key)
has_active_provider = bind.execute(
sa.select(table.c.id).where(table.c.is_active.is_(True))
).first()
bind.execute(
table.insert().values(
name=EXA_PROVIDER_NAME,
provider_type="exa",
api_key=encrypted_key,
config=None,
is_active=not bool(has_active_provider),
)
)
def downgrade() -> None:
return

View File

@@ -1,4 +1,16 @@
variable "REPOSITORY" {
group "default" {
targets = ["backend", "model-server"]
}
variable "BACKEND_REPOSITORY" {
default = "onyxdotapp/onyx-backend"
}
variable "MODEL_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-model-server"
}
variable "INTEGRATION_REPOSITORY" {
default = "onyxdotapp/onyx-integration"
}
@@ -9,6 +21,22 @@ variable "TAG" {
target "backend" {
context = "."
dockerfile = "Dockerfile"
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}
target "model-server" {
context = "."
dockerfile = "Dockerfile.model_server"
cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
}
target "integration" {
@@ -20,8 +48,5 @@ target "integration" {
base = "target:backend"
}
cache-from = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache"]
cache-to = ["type=registry,ref=${REPOSITORY}:integration-test-backend-cache,mode=max"]
tags = ["${REPOSITORY}:${TAG}"]
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
}

View File

@@ -124,6 +124,8 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
GATED_TENANTS_KEY = "gated_tenants"

View File

@@ -1,7 +1,10 @@
import json
from typing import Any
from urllib.parse import unquote
from posthog import Posthog
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
@@ -20,3 +23,80 @@ posthog = Posthog(
debug=True,
on_error=posthog_on_error,
)
# For cross referencing between cloud and www Onyx sites
# NOTE: These clients are separate because they are separate posthog projects.
# We should eventually unify them into a single posthog project,
# which would no longer require this workaround
marketing_posthog = None
if MARKETING_POSTHOG_API_KEY:
marketing_posthog = Posthog(
project_api_key=MARKETING_POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def capture_and_sync_with_alternate_posthog(
alternate_distinct_id: str, event: str, properties: dict[str, Any]
) -> None:
"""
Identify in both PostHog projects and capture the event in marketing.
- Marketing keeps the marketing distinct_id (for feature flags).
- Cloud identify uses the cloud distinct_id
"""
if not marketing_posthog:
return
props = properties.copy()
try:
marketing_posthog.identify(distinct_id=alternate_distinct_id, properties=props)
marketing_posthog.capture(alternate_distinct_id, event, props)
marketing_posthog.flush()
except Exception as e:
logger.error(f"Error capturing marketing posthog event: {e}")
try:
if cloud_user_id := props.get("onyx_cloud_user_id"):
cloud_props = props.copy()
cloud_props.pop("onyx_cloud_user_id", None)
posthog.identify(
distinct_id=cloud_user_id,
properties=cloud_props,
)
except Exception as e:
logger.error(f"Error identifying cloud posthog user: {e}")
def get_marketing_posthog_cookie_name() -> str | None:
if not MARKETING_POSTHOG_API_KEY:
return None
return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog"
def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
"""
Parse the URL-encoded JSON marketing cookie.
Expected format (URL-encoded):
{"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...}
Returns:
Dict with 'distinct_id' explicitly required and all other cookie values
passed through as-is, or None if parsing fails or distinct_id is missing.
"""
try:
decoded_cookie = unquote(cookie_value)
cookie_data = json.loads(decoded_cookie)
distinct_id = cookie_data.get("distinct_id")
if not distinct_id:
return None
return cookie_data
except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
logger.warning(f"Failed to parse cookie: {e}")
return None
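A minimal sketch of the URL-encoded JSON cookie format that parse_marketing_cookie expects, assuming only what the docstring above states; the cookie payload values are invented for illustration.

import json
from urllib.parse import quote, unquote

# Hypothetical payload written by the marketing site's PostHog snippet.
raw_payload = {
    "distinct_id": "marketing-user-123",            # required by the parser
    "featureFlags": {"landing_page_variant": "b"},  # passed through as-is
}
cookie_value = quote(json.dumps(raw_payload))

# Mirrors the happy path of parse_marketing_cookie: unquote, json.loads,
# then require a non-empty distinct_id.
decoded = json.loads(unquote(cookie_value))
assert decoded["distinct_id"] == "marketing-user-123"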

View File

@@ -116,7 +116,7 @@ def _concurrent_embedding(
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.error(f"Error encoding texts, retrying: {e}")
logger.warning(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)

View File

@@ -0,0 +1,73 @@
import json
from collections.abc import Sequence
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import FunctionMessage
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import FunctionCall
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithText
HUMAN = "human"
SYSTEM = "system"
AI = "ai"
FUNCTION = "function"
def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
def _base_message_to_chat_completion_msg(
msg: BaseMessage,
) -> ChatCompletionMessage:
if msg.type == HUMAN:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
user_msg: UserMessageWithText = {"role": "user", "content": content}
return user_msg
if msg.type == SYSTEM:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
system_msg: SystemMessage = {"role": "system", "content": content}
return system_msg
if msg.type == AI:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
if isinstance(msg, AIMessage) and msg.tool_calls:
assistant_msg["tool_calls"] = [
ToolCall(
id=tool_call.get("id") or "",
type="function",
function=FunctionCall(
name=tool_call["name"],
arguments=json.dumps(tool_call["args"]),
),
)
for tool_call in msg.tool_calls
]
return assistant_msg
if msg.type == FUNCTION:
function_message = cast(FunctionMessage, msg)
content = (
function_message.content
if isinstance(function_message.content, str)
else str(function_message.content)
)
tool_msg: ToolMessage = {
"role": "tool",
"content": content,
"tool_call_id": function_message.name or "",
}
return tool_msg
raise ValueError(f"Unexpected message type: {msg.type}")

View File

@@ -1,9 +1,11 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.agents.agent_framework.models import RunItemStreamEvent
from onyx.agents.agent_framework.models import StreamEvent
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
@@ -16,6 +18,10 @@ from onyx.llm.message_types import ToolCall
from onyx.llm.model_response import ModelResponseStream
from onyx.tools.tool import RunContextWrapper
from onyx.tools.tool import Tool
from onyx.tracing.framework.create import agent_span
from onyx.tracing.framework.create import function_span
from onyx.tracing.framework.create import generation_span
from onyx.tracing.framework.spans import SpanError
@dataclass
@@ -33,6 +39,75 @@ def _serialize_tool_output(output: Any) -> str:
return str(output)
def _parse_tool_calls_from_message_content(
content: str,
) -> list[dict[str, Any]]:
"""Parse JSON content that represents tool call instructions."""
try:
parsed_content = json.loads(content)
except json.JSONDecodeError:
return []
if isinstance(parsed_content, dict):
candidates = [parsed_content]
elif isinstance(parsed_content, list):
candidates = [item for item in parsed_content if isinstance(item, dict)]
else:
return []
tool_calls: list[dict[str, Any]] = []
for candidate in candidates:
name = candidate.get("name")
arguments = candidate.get("arguments")
if not isinstance(name, str) or arguments is None:
continue
if not isinstance(arguments, dict):
continue
call_id = candidate.get("id")
arguments_str = json.dumps(arguments)
tool_calls.append(
{
"id": call_id,
"name": name,
"arguments": arguments_str,
}
)
return tool_calls
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress: dict[int, dict[str, Any]],
content_parts: list[str],
structured_response_format: dict | None,
next_synthetic_tool_call_id: Callable[[], str],
) -> None:
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
if tool_calls_in_progress or not content_parts or structured_response_format:
return
tool_calls_from_content = _parse_tool_calls_from_message_content(
"".join(content_parts)
)
if not tool_calls_from_content:
return
content_parts.clear()
for index, tool_call_data in enumerate(tool_calls_from_content):
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
tool_calls_in_progress[index] = {
"id": call_id,
"name": tool_call_data["name"],
"arguments": tool_call_data["arguments"],
}
def _update_tool_call_with_delta(
tool_calls_in_progress: dict[int, dict[str, Any]],
tool_call_delta: Any,
@@ -65,149 +140,224 @@ def query(
tools: Sequence[Tool],
context: Any,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> QueryResult:
tool_definitions = [tool.tool_definition() for tool in tools]
tools_by_name = {tool.name: tool for tool in tools}
new_messages_stateful: list[ChatCompletionMessage] = []
current_span = agent_span(
name="agent_framework_query",
output_type="dict" if structured_response_format else "str",
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in tools]
def stream_generator() -> Iterator[StreamEvent]:
reasoning_started = False
message_started = False
reasoning_started = False
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
content_parts: list[str] = []
reasoning_parts: list[str] = []
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
):
assert isinstance(chunk, ModelResponseStream)
synthetic_tool_call_counter = 0
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
def _next_synthetic_tool_call_id() -> str:
nonlocal synthetic_tool_call_counter
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
synthetic_tool_call_counter += 1
return call_id
if delta.reasoning_content:
reasoning_parts.append(delta.reasoning_content)
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
with generation_span( # type: ignore[misc]
model=llm_with_default_settings.config.model_name,
model_config={
"base_url": str(llm_with_default_settings.config.api_base or ""),
"model_impl": "litellm",
},
) as span_generation:
# Only set input if messages is a sequence (not a string)
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
if isinstance(messages, Sequence) and not isinstance(messages, str):
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
for chunk in llm_with_default_settings.stream(
prompt=messages,
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=structured_response_format,
):
assert isinstance(chunk, ModelResponseStream)
usage = getattr(chunk, "usage", None)
if usage:
span_generation.span_data.usage = {
"input_tokens": usage.prompt_tokens,
"output_tokens": usage.completion_tokens,
"cache_read_input_tokens": usage.cache_read_input_tokens,
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
}
delta = chunk.choice.delta
finish_reason = chunk.choice.finish_reason
if delta.reasoning_content:
if not reasoning_started:
yield RunItemStreamEvent(type="reasoning_start")
reasoning_started = True
if delta.content:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
content_parts.append(delta.content)
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
)
yield chunk
if not finish_reason:
continue
if delta.content:
content_parts.append(delta.content)
if reasoning_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if not message_started:
yield RunItemStreamEvent(type="message_start")
message_started = True
if delta.tool_calls:
if reasoning_started and not message_started:
yield RunItemStreamEvent(type="reasoning_done")
reasoning_started = False
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(
tool_calls_in_progress, tool_call_delta
if tool_choice != "none":
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
tool_calls_in_progress,
content_parts,
structured_response_format,
_next_synthetic_tool_call_id,
)
yield chunk
if not finish_reason:
continue
if message_started:
yield RunItemStreamEvent(type="message_done")
message_started = False
if finish_reason == "tool_calls" and tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
assistant_tool_calls.append(
if content_parts:
new_messages_stateful.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
"role": "assistant",
"content": "".join(content_parts),
}
)
span_generation.span_data.output = new_messages_stateful
# Execute tool calls outside of the stream loop and generation_span
if tool_calls_in_progress:
sorted_tool_calls = sorted(tool_calls_in_progress.items())
# Build tool calls for the message and execute tools
assistant_tool_calls: list[ToolCall] = []
tool_outputs: dict[str, str] = {}
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
name = tool_call_data["name"]
arguments_str = tool_call_data["arguments"]
if call_id is None or name is None:
continue
assistant_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments_str,
},
}
)
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
),
)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
run_context = RunContextWrapper(context=context)
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
with function_span(tool.name) as span_fn:
span_fn.span_data.input = arguments
try:
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
span_fn.span_data.output = output
except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running tool",
data={"tool_name": tool.name, "error": str(e)},
)
)
# Treat the error as the tool output so the framework can continue
error_output = f"Error: {str(e)}"
tool_outputs[call_id] = error_output
output = error_output
yield RunItemStreamEvent(
type="tool_call",
details=ToolCallStreamItem(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
name=name,
arguments=arguments_str,
output=output,
),
)
else:
not_found_output = f"Tool {name} not found"
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=not_found_output,
),
)
if name in tools_by_name:
tool = tools_by_name[name]
arguments = json.loads(arguments_str)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
run_context = RunContextWrapper(context=context)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
# TODO: Instead of executing sequentially, execute in parallel
# In practice, it's not a must right now since we don't use parallel
# tool calls, so kicking the can down the road for now.
output = tool.run_v2(run_context, **arguments)
tool_outputs[call_id] = _serialize_tool_output(output)
yield RunItemStreamEvent(
type="tool_call_output",
details=ToolCallOutputStreamItem(
call_id=call_id,
output=output,
),
)
new_messages_stateful.append(
{
"role": "assistant",
"content": None,
"tool_calls": assistant_tool_calls,
}
)
for _, tool_call_data in sorted_tool_calls:
call_id = tool_call_data["id"]
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
elif finish_reason == "stop" and content_parts:
new_messages_stateful.append(
{
"role": "assistant",
"content": "".join(content_parts),
}
)
if call_id in tool_outputs:
new_messages_stateful.append(
{
"role": "tool",
"content": tool_outputs[call_id],
"tool_call_id": call_id,
}
)
current_span.finish(reset_current=True)
return QueryResult(
stream=stream_generator(),
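A short sketch of the JSON content shape that _parse_tool_calls_from_message_content accepts from a non-tool-calling LLM; the tool name and arguments are invented, and entries without an "id" receive a synthetic_tool_call_<n> id downstream.

import json

# What a non-tool-calling LLM might emit as plain message content.
llm_content = json.dumps(
    [{"name": "internal_search", "arguments": {"query": "quarterly revenue"}}]
)

candidates = json.loads(llm_content)
for candidate in candidates:
    # Entries need a string "name" and a dict "arguments"; "id" is optional.
    assert isinstance(candidate["name"], str)
    assert isinstance(candidate["arguments"], dict)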

View File

@@ -26,9 +26,9 @@ def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> Non
# Without this patch, the library uses special formatting that breaks our custom tools
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
if tool_choice == "web_search":
return {"type": "function", "name": "web_search"}
return "web_search"
if tool_choice == "image_generation":
return {"type": "function", "name": "image_generation"}
return "image_generation"
return orig_func(cls, tool_choice)
OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]

View File

@@ -12,13 +12,12 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
def __init__(self, api_key: str) -> None:
self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
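A minimal usage sketch reflecting the signature change: ExaClient no longer falls back to the EXA_API_KEY env var, so callers (normally build_search_provider_from_config below) must pass the key explicitly. The key here is a placeholder and a real call needs a valid Exa account.

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
    ExaClient,
)

client = ExaClient(api_key="exa-key-from-provider-row")  # placeholder key
results = client.search("onyx open source rag")  # returns WebSearchResult objects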

View File

@@ -0,0 +1,163 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
FIRECRAWL_SCRAPE_URL = "https://api.firecrawl.dev/v1/scrape"
_DEFAULT_MAX_WORKERS = 4
@dataclass
class ExtractedContentFields:
text: str
title: str
published_date: datetime | None
class FirecrawlClient(WebContentProvider):
def __init__(
self,
api_key: str,
*,
base_url: str = FIRECRAWL_SCRAPE_URL,
timeout_seconds: int = 30,
) -> None:
self._headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
self._base_url = base_url
self._timeout_seconds = timeout_seconds
self._last_error: str | None = None
@property
def last_error(self) -> str | None:
return self._last_error
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
max_workers = min(_DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._get_webpage_content_safe, urls))
def _get_webpage_content_safe(self, url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception as exc:
self._last_error = str(exc)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
"formats": ["markdown"],
}
response = requests.post(
self._base_url,
headers=self._headers,
json=payload,
timeout=self._timeout_seconds,
)
if response.status_code != 200:
try:
error_payload = response.json()
except Exception:
error_payload = response.text
self._last_error = (
error_payload if isinstance(error_payload, str) else str(error_payload)
)
if 400 <= response.status_code < 500:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
raise ValueError(
f"Firecrawl fetch failed with status {response.status_code}."
)
else:
self._last_error = None
response_json = response.json()
extracted = self._extract_content_fields(response_json, url)
return WebContent(
title=extracted.title,
link=url,
full_content=extracted.text,
published_date=extracted.published_date,
scrape_successful=bool(extracted.text),
)
@staticmethod
def _extract_content_fields(
response_json: dict[str, Any], url: str
) -> ExtractedContentFields:
data_section = response_json.get("data") or {}
metadata = data_section.get("metadata") or response_json.get("metadata") or {}
text_candidates = [
data_section.get("markdown"),
data_section.get("content"),
data_section.get("text"),
response_json.get("markdown"),
response_json.get("content"),
response_json.get("text"),
]
text = next((candidate for candidate in text_candidates if candidate), "")
title = metadata.get("title") or response_json.get("title") or ""
published_date = None
published_date_str = (
metadata.get("publishedTime")
or metadata.get("date")
or response_json.get("publishedTime")
or response_json.get("date")
)
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
if not text:
logger.warning(f"Firecrawl returned empty content for url={url}")
return ExtractedContentFields(
text=text or "",
title=title or "",
published_date=published_date,
)
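A minimal usage sketch for the new FirecrawlClient with a placeholder API key; failed scrapes come back with scrape_successful=False instead of raising, and last_error carries the most recent failure detail.

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
    FirecrawlClient,
)

client = FirecrawlClient(api_key="fc-placeholder-key")
pages = client.contents(["https://www.onyx.app"])
for page in pages:
    if page.scrape_successful:
        print(page.title, len(page.full_content))
    else:
        print(f"scrape failed for {page.link}: {client.last_error}")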

View File

@@ -0,0 +1,138 @@
from __future__ import annotations
from collections.abc import Sequence
from datetime import datetime
from typing import Any
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebContent
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
GOOGLE_CUSTOM_SEARCH_URL = "https://customsearch.googleapis.com/customsearch/v1"
class GooglePSEClient(WebSearchProvider):
def __init__(
self,
api_key: str,
search_engine_id: str,
*,
num_results: int = 10,
timeout_seconds: int = 10,
) -> None:
self._api_key = api_key
self._search_engine_id = search_engine_id
self._num_results = num_results
self._timeout_seconds = timeout_seconds
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
params: dict[str, str] = {
"key": self._api_key,
"cx": self._search_engine_id,
"q": query,
"num": str(self._num_results),
}
response = requests.get(
GOOGLE_CUSTOM_SEARCH_URL, params=params, timeout=self._timeout_seconds
)
# Check for HTTP errors first
try:
response.raise_for_status()
except requests.HTTPError as exc:
status = response.status_code
error_detail = "Unknown error"
try:
error_data = response.json()
if "error" in error_data:
error_info = error_data["error"]
error_detail = error_info.get("message", str(error_info))
except Exception:
error_detail = (
response.text[:200] if response.text else "No error details"
)
raise ValueError(
f"Google PSE search failed (status {status}): {error_detail}"
) from exc
data = response.json()
# Google Custom Search API can return errors in the response body even with 200 status
if "error" in data:
error_info = data["error"]
error_message = error_info.get("message", "Unknown error")
error_code = error_info.get("code", "Unknown")
raise ValueError(f"Google PSE API error ({error_code}): {error_message}")
items: list[dict[str, Any]] = data.get("items", [])
results: list[WebSearchResult] = []
for item in items:
link = item.get("link")
if not link:
continue
snippet = item.get("snippet") or ""
# Attempt to extract metadata if available
pagemap = item.get("pagemap") or {}
metatags = pagemap.get("metatags", [])
published_date: datetime | None = None
author: str | None = None
if metatags:
meta = metatags[0]
author = meta.get("og:site_name") or meta.get("author")
published_str = (
meta.get("article:published_time")
or meta.get("og:updated_time")
or meta.get("date")
)
if published_str:
try:
published_date = datetime.fromisoformat(
published_str.replace("Z", "+00:00")
)
except ValueError:
logger.debug(
f"Failed to parse published_date '{published_str}' for link {link}"
)
published_date = None
results.append(
WebSearchResult(
title=item.get("title") or "",
link=link,
snippet=snippet,
author=author,
published_date=published_date,
)
)
return results
def contents(self, urls: Sequence[str]) -> list[WebContent]:
logger.warning(
"Google PSE does not support content fetching; returning empty results."
)
return [
WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
for url in urls
]
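A minimal usage sketch for GooglePSEClient with placeholder credentials; note that contents() is intentionally a stub for this provider and always returns unscraped entries.

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
    GooglePSEClient,
)

client = GooglePSEClient(
    api_key="google-api-key",              # placeholder
    search_engine_id="0123456789abcdef0",  # placeholder "cx" value
)
for result in client.search("onyx self-hosted search"):
    print(result.title, result.link)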

View File

@@ -0,0 +1,94 @@
from __future__ import annotations
from collections.abc import Sequence
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
logger = setup_logger()
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
class OnyxWebCrawlerClient(WebContentProvider):
"""
Lightweight built-in crawler that fetches HTML directly and extracts readable text.
Acts as the default content provider when no external crawler (e.g. Firecrawl) is
configured.
"""
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
) -> None:
self._timeout_seconds = timeout_seconds
self._headers = {
"User-Agent": user_agent,
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
def _fetch_url(self, url: str) -> WebContent:
try:
response = requests.get(
url, headers=self._headers, timeout=self._timeout_seconds
)
except Exception as exc: # pragma: no cover - network failures vary
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
try:
parsed: ParsedHTML = web_html_cleanup(response.text)
text_content = parsed.cleaned_text or ""
title = parsed.title or ""
except Exception as exc:
logger.warning(
"Onyx crawler failed to parse %s (%s)", url, exc.__class__.__name__
)
text_content = ""
title = ""
return WebContent(
title=title,
link=url,
full_content=text_content,
published_date=None,
scrape_successful=bool(text_content.strip()),
)
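A minimal usage sketch for the built-in OnyxWebCrawlerClient, which needs no API key and acts as the fallback content provider when no external crawler is configured; the URL is a placeholder.

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
    OnyxWebCrawlerClient,
)

crawler = OnyxWebCrawlerClient(timeout_seconds=10)
for page in crawler.contents(["https://docs.onyx.app"]):
    print(page.link, "ok" if page.scrape_successful else "failed")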

View File

@@ -13,7 +13,6 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
@@ -22,7 +21,7 @@ SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
def __init__(self, api_key: str) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
@@ -40,7 +39,13 @@ class SerperClient(WebSearchProvider):
data=json.dumps(payload),
)
response.raise_for_status()
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper search failed. Check credentials or quota."
) from None
results = response.json()
organic_results = results["organic"]
@@ -99,7 +104,13 @@ class SerperClient(WebSearchProvider):
scrape_successful=False,
)
response.raise_for_status()
try:
response.raise_for_status()
except Exception:
# Avoid leaking API keys/URLs
raise ValueError(
"Serper content fetch failed. Check credentials."
) from None
response_json = response.json()
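A short sketch of the new error behavior: Serper HTTP failures now surface as a generic ValueError so API keys and request URLs never leak into logs. The key is a placeholder.

from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
    SerperClient,
)

client = SerperClient(api_key="serper-placeholder-key")
try:
    results = client.search("onyx connectors")
except ValueError as e:
    # e.g. "Serper search failed. Check credentials or quota."
    print(f"search unavailable: {e}")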

View File

@@ -74,13 +74,22 @@ def web_search(
if not provider:
raise ValueError("No internet search provider found")
# Log which provider type is being used
provider_type = type(provider).__name__
logger.info(
f"Performing web search with {provider_type} for query: '{search_query}'"
)
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = list(provider.search(search_query))
logger.info(
f"Search returned {len(search_results)} results using {provider_type}"
)
except Exception as e:
logger.error(f"Error performing search: {e}")
logger.error(f"Error performing search with {provider_type}: {e}")
return search_results
search_results: list[WebSearchResult] = _search(search_query)

View File

@@ -5,7 +5,7 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
@@ -41,9 +41,9 @@ def web_fetch(
if graph_config.inputs.persona is None:
raise ValueError("persona is not set")
provider = get_default_provider()
provider = get_default_content_provider()
if provider is None:
raise ValueError("No web search provider found")
raise ValueError("No web content provider found")
retrieved_docs: list[InferenceSection] = []
try:
@@ -52,7 +52,7 @@ def web_fetch(
for result in provider.contents(state.urls_to_open)
]
except Exception as e:
logger.error(f"Error fetching URLs: {e}")
logger.exception(e)
if not retrieved_docs:
logger.warning("No content retrieved from URLs")

View File

@@ -2,7 +2,6 @@ from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import field_validator
@@ -10,13 +9,6 @@ from pydantic import field_validator
from onyx.utils.url import normalize_url
class ProviderType(Enum):
"""Enum for internet search provider types"""
GOOGLE = "google"
EXA = "exa"
class WebSearchResult(BaseModel):
title: str
link: str
@@ -43,11 +35,13 @@ class WebContent(BaseModel):
return normalize_url(v)
class WebSearchProvider(ABC):
@abstractmethod
def search(self, query: str) -> Sequence[WebSearchResult]:
pass
class WebContentProvider(ABC):
@abstractmethod
def contents(self, urls: Sequence[str]) -> list[WebContent]:
pass
class WebSearchProvider(WebContentProvider):
@abstractmethod
def search(self, query: str) -> Sequence[WebSearchResult]:
pass
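
With this change every search provider is also a content provider (WebSearchProvider now subclasses WebContentProvider), while scrape-only backends implement just contents(). A minimal sketch of a conforming implementation, assuming the abstract methods exactly as shown above; the lightweight stand-in result classes and DummySearchProvider are illustrative only.

from abc import ABC, abstractmethod
from collections.abc import Sequence

# Lightweight stand-ins for the pydantic models so the sketch is self-contained.
class WebSearchResult:
    def __init__(self, title: str, link: str) -> None:
        self.title, self.link = title, link

class WebContent:
    def __init__(self, title: str, link: str, full_content: str, scrape_successful: bool) -> None:
        self.title, self.link = title, link
        self.full_content, self.scrape_successful = full_content, scrape_successful

class WebContentProvider(ABC):
    @abstractmethod
    def contents(self, urls: Sequence[str]) -> list[WebContent]: ...

class WebSearchProvider(WebContentProvider):
    @abstractmethod
    def search(self, query: str) -> Sequence[WebSearchResult]: ...

class DummySearchProvider(WebSearchProvider):
    """A search provider must now also be able to fetch page contents."""

    def search(self, query: str) -> Sequence[WebSearchResult]:
        return [WebSearchResult(title=f"Result for {query}", link="https://example.com")]

    def contents(self, urls: Sequence[str]) -> list[WebContent]:
        return [
            WebContent(title=url, link=url, full_content="", scrape_successful=False)
            for url in urls
        ]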

View File

@@ -1,19 +1,199 @@
from typing import Any
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FIRECRAWL_SCRAPE_URL,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.firecrawl_client import (
FirecrawlClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.google_pse_client import (
GooglePSEClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def build_search_provider_from_config(
*,
provider_type: WebSearchProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_search_provider",
) -> WebSearchProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebSearchProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
# All web search providers require an API key
if not api_key:
raise ValueError(
f"Web search provider '{provider_name}' is missing an API key."
)
assert api_key is not None
config = config or {}
if provider_type_enum == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key)
if provider_type_enum == WebSearchProviderType.GOOGLE_PSE:
search_engine_id = (
config.get("search_engine_id")
or config.get("cx")
or config.get("search_engine")
)
if not search_engine_id:
raise ValueError(
"Google PSE provider requires a search engine id (cx) in addition to the API key."
)
assert search_engine_id is not None
try:
num_results = int(config.get("num_results", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'num_results'; expected integer."
)
try:
timeout_seconds = int(config.get("timeout_seconds", 10))
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Google PSE 'timeout_seconds'; expected integer."
)
return GooglePSEClient(
api_key=api_key,
search_engine_id=search_engine_id,
num_results=num_results,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web search provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_search_provider(provider_model: Any) -> WebSearchProvider | None:
return build_search_provider_from_config(
provider_type=WebSearchProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
def build_content_provider_from_config(
*,
provider_type: WebContentProviderType,
api_key: str | None,
config: dict[str, str] | None,
provider_name: str = "web_content_provider",
) -> WebContentProvider | None:
provider_type_value = provider_type.value
try:
provider_type_enum = WebContentProviderType(provider_type_value)
except ValueError:
logger.error(
f"Unknown web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
if provider_type_enum == WebContentProviderType.ONYX_WEB_CRAWLER:
config = config or {}
timeout_value = config.get("timeout_seconds", 15)
try:
timeout_seconds = int(timeout_value)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Onyx Web Crawler 'timeout_seconds'; expected integer."
)
return OnyxWebCrawlerClient(timeout_seconds=timeout_seconds)
if provider_type_enum == WebContentProviderType.FIRECRAWL:
if not api_key:
raise ValueError("Firecrawl content provider requires an API key.")
assert api_key is not None
config = config or {}
timeout_seconds_str = config.get("timeout_seconds")
if timeout_seconds_str is None:
timeout_seconds = 10
else:
try:
timeout_seconds = int(timeout_seconds_str)
except (TypeError, ValueError):
raise ValueError(
"Invalid value for Firecrawl 'timeout_seconds'; expected integer."
)
return FirecrawlClient(
api_key=api_key,
base_url=config.get("base_url") or FIRECRAWL_SCRAPE_URL,
timeout_seconds=timeout_seconds,
)
logger.error(
f"Unhandled web content provider type '{provider_type_value}'. "
"Skipping provider initialization."
)
return None
def _build_content_provider(provider_model: Any) -> WebContentProvider | None:
return build_content_provider_from_config(
provider_type=WebContentProviderType(provider_model.provider_type),
api_key=provider_model.api_key,
config=provider_model.config or {},
provider_name=provider_model.name,
)
def get_default_provider() -> WebSearchProvider | None:
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:
return SerperClient()
return None
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_search_provider(db_session)
if provider_model is None:
return None
return _build_search_provider(provider_model)
def get_default_content_provider() -> WebContentProvider | None:
with get_session_with_current_tenant() as db_session:
provider_model = fetch_active_web_content_provider(db_session)
if provider_model:
provider = _build_content_provider(provider_model)
if provider:
return provider
# Fall back to built-in Onyx crawler when nothing is configured.
try:
return OnyxWebCrawlerClient()
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to initialize default Onyx crawler: {exc}")
return None
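
Since providers are now resolved from the database rather than the EXA_API_KEY / SERPER_API_KEY env vars, callers look them up at runtime. A rough usage sketch assuming the factory functions behave exactly as defined above; run_web_search and its print statements are illustrative only.

from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
    get_default_content_provider,
    get_default_provider,
)

def run_web_search(query: str, urls_to_open: list[str]) -> None:
    """Illustrative call site: resolve DB-configured providers and use them."""
    search_provider = get_default_provider()
    if search_provider is None:
        raise ValueError("No active web search provider is configured")

    results = list(search_provider.search(query))
    print(f"{type(search_provider).__name__} returned {len(results)} results")

    content_provider = get_default_content_provider()
    if content_provider is None:
        # get_default_content_provider already falls back to OnyxWebCrawlerClient,
        # so this only happens if that fallback itself fails to initialize.
        raise ValueError("No web content provider available")

    for page in content_provider.contents(urls_to_open):
        print(page.link, page.scrape_successful)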

View File

@@ -11,11 +11,8 @@ from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -131,7 +128,7 @@ def _resolve_public_key_from_jwks(
return None
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
async def verify_jwt_token(token: str) -> dict[str, Any] | None:
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
public_key = get_public_key(token)
if public_key is None:
@@ -142,8 +139,6 @@ async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User |
return None
try:
from sqlalchemy import func
payload = jwt_decode(
token,
public_key,
@@ -163,15 +158,6 @@ async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User |
continue
return None
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
break
return payload
return None

View File

@@ -611,6 +611,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
request=request,
)
user_count = None
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
@@ -633,6 +634,57 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# Fetch EE PostHog functions if available
get_marketing_posthog_cookie_name = fetch_ee_implementation_or_noop(
module="onyx.utils.posthog_client",
attribute="get_marketing_posthog_cookie_name",
noop_return_value=None,
)
parse_marketing_cookie = fetch_ee_implementation_or_noop(
module="onyx.utils.posthog_client",
attribute="parse_marketing_cookie",
noop_return_value=None,
)
capture_and_sync_with_alternate_posthog = fetch_ee_implementation_or_noop(
module="onyx.utils.posthog_client",
attribute="capture_and_sync_with_alternate_posthog",
noop_return_value=None,
)
if (
request
and user_count is not None
and (marketing_cookie_name := get_marketing_posthog_cookie_name())
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
and (parsed_cookie := parse_marketing_cookie(marketing_cookie_value))
):
marketing_anonymous_id = parsed_cookie["distinct_id"]
# Technically, the cloud site only fires USER_SIGNED_UP for the first user
# in a tenant. For the marketing site the event is semantically correct for
# every signup; the cloud-site behavior should probably be refactored to match.
properties = {
"email": user.email,
"onyx_cloud_user_id": str(user.id),
"tenant_id": str(tenant_id) if tenant_id else None,
"role": user.role.value,
"is_first_user": user_count == 1,
"source": "marketing_site_signup",
"conversion_timestamp": datetime.now(timezone.utc).isoformat(),
}
# Add all other values from the marketing cookie (featureFlags, etc.)
for key, value in parsed_cookie.items():
if key != "distinct_id":
properties.setdefault(key, value)
capture_and_sync_with_alternate_posthog(
alternate_distinct_id=marketing_anonymous_id,
event=MilestoneRecordType.USER_SIGNED_UP,
properties=properties,
)
logger.debug(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
@@ -1063,6 +1115,107 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
_JWT_EMAIL_CLAIM_KEYS = ("email", "preferred_username", "upn")
def _extract_email_from_jwt(payload: dict[str, Any]) -> str | None:
"""Return the best-effort email/username from a decoded JWT payload."""
for key in _JWT_EMAIL_CLAIM_KEYS:
value = payload.get(key)
if isinstance(value, str) and value:
try:
email_info = validate_email(value, check_deliverability=False)
except EmailNotValidError:
continue
normalized_email = email_info.normalized or email_info.email
return normalized_email.lower()
return None
async def _sync_jwt_oidc_expiry(
user_manager: UserManager, user: User, payload: dict[str, Any]
) -> None:
if TRACK_EXTERNAL_IDP_EXPIRY:
expires_at = payload.get("exp")
if expires_at is None:
return
try:
expiry_timestamp = int(expires_at)
except (TypeError, ValueError):
logger.warning("Invalid exp claim on JWT for user %s", user.email)
return
oidc_expiry = datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc)
if user.oidc_expiry == oidc_expiry:
return
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
user.oidc_expiry = oidc_expiry # type: ignore
return
if user.oidc_expiry is not None:
await user_manager.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
async def _get_or_create_user_from_jwt(
payload: dict[str, Any],
request: Request,
async_db_session: AsyncSession,
) -> User | None:
email = _extract_email_from_jwt(payload)
if email is None:
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
return None
# Enforce the same allowlist/domain policies as other auth flows
verify_email_is_invited(email)
verify_email_domain(email)
user_db: SQLAlchemyUserAdminDB[User, uuid.UUID] = SQLAlchemyUserAdminDB(
async_db_session, User, OAuthAccount
)
user_manager = UserManager(user_db)
try:
user = await user_manager.get_by_email(email)
if not user.is_active:
logger.warning("Inactive user %s attempted JWT login; skipping", email)
return None
if not user.role.is_web_login():
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
logger.info("Provisioning user %s from JWT login", email)
try:
user = await user_manager.create(
UserCreate(
email=email,
password=generate_password(),
is_verified=True,
),
request=request,
)
except exceptions.UserAlreadyExists:
user = await user_manager.get_by_email(email)
if not user.is_active:
logger.warning(
"Inactive user %s attempted JWT login during provisioning race; skipping",
email,
)
return None
if not user.role.is_web_login():
logger.warning(
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
email,
)
return None
await _sync_jwt_oidc_expiry(user_manager, user, payload)
return user
async def _check_for_saml_and_jwt(
request: Request,
user: User | None,
@@ -1073,7 +1226,11 @@ async def _check_for_saml_and_jwt(
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
user = await verify_jwt_token(token, async_db_session)
payload = await verify_jwt_token(token)
if payload is not None:
user = await _get_or_create_user_from_jwt(
payload, request, async_db_session
)
return user
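
The claim precedence in _extract_email_from_jwt (email, then preferred_username, then upn) plus the exp handling in _sync_jwt_oidc_expiry means a decoded payload needs one address-like claim and, optionally, an integer expiry. A tiny sketch of payload shapes, with purely illustrative values:

# Payload shapes for the helpers above (illustrative values only).
ok_payload = {
    "preferred_username": "dev.user@example.com",  # no "email" claim, so lookup falls through to this
    "exp": 1_767_225_600,  # integer expiry; synced to oidc_expiry when TRACK_EXTERNAL_IDP_EXPIRY is on
}
rejected_payload = {
    "sub": "abc123",  # no email-like claim at all -> _get_or_create_user_from_jwt returns None
    "exp": "not-a-number",  # a non-integer exp is skipped by _sync_jwt_oidc_expiry with a warning
}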

View File

@@ -341,6 +341,9 @@ def stream_chat_message_objects(
# messages.
# NOTE: is not stored in the database at all.
single_message_history: str | None = None,
# Thread messages as PreviousMessage objects for proper query rephrasing
# Used in Slack bot flow to enable contextual search
thread_message_history: list[PreviousMessage] | None = None,
) -> AnswerStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -757,6 +760,13 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
# If thread messages are provided (e.g., from Slack), use them for query rephrasing
if thread_message_history:
message_history = thread_message_history + message_history
logger.info(
f"Added {len(thread_message_history)} thread messages to history for query rephrasing"
)
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
# we only want to send the user files attached to the current message
yield UserKnowledgeFilePacket(
@@ -781,6 +791,7 @@ def stream_chat_message_objects(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
)
memories = get_memories(user, db_session)
system_message = (

View File

@@ -171,7 +171,7 @@ def default_build_user_message(
task_prompt=prompt_config.reminder,
user_query=user_query,
)
if prompt_config.reminder
if prompt_config.reminder or single_message_history
else user_query
)

View File

@@ -53,8 +53,8 @@ from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.adapter_v1_to_v2 import force_use_tool_to_function_tool_names
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import filter_tools_for_force_tool_use
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
@@ -109,6 +109,10 @@ def _run_agent_loop(
available_tools: Sequence[Tool] = (
dependencies.tools if iteration_count < MAX_ITERATIONS else []
)
if force_use_tool and force_use_tool.force_use:
available_tools = filter_tools_for_force_tool_use(
list(available_tools), force_use_tool
)
memories = get_memories(dependencies.user_or_none, dependencies.db_session)
# TODO: The system is rather prompt-cache efficient except for rebuilding the system prompt.
# The biggest offender is when we hit max iterations and then all the tool calls cannot
@@ -142,10 +146,8 @@ def _run_agent_loop(
tool_choice = None
else:
tool_choice = (
force_use_tool_to_function_tool_names(force_use_tool, available_tools)
if iteration_count == 0 and force_use_tool
else None
) or "auto"
"required" if force_use_tool and force_use_tool.force_use else "auto"
)
model_settings = replace(dependencies.model_settings, tool_choice=tool_choice)
agent = Agent(
@@ -430,9 +432,9 @@ def _default_packet_translation(
# (e.g. if we've already sent the MessageStart / MessageDelta packets, then we
# shouldn't do anything)
if ctx.current_output_index == output_index:
packets.append(Packet(ind=ctx.current_run_step, obj=SectionEnd()))
ctx.current_run_step += 1
ctx.current_output_index = None
packets.append(Packet(ind=ctx.current_run_step, obj=SectionEnd()))
# ------------------------------------------------------------
# Message packets

View File

@@ -581,6 +581,12 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
# TestRail specific configs
TESTRAIL_BASE_URL = os.environ.get("TESTRAIL_BASE_URL", "")
TESTRAIL_USERNAME = os.environ.get("TESTRAIL_USERNAME", "")
TESTRAIL_API_KEY = os.environ.get("TESTRAIL_API_KEY", "")
LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE = (
os.environ.get("LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE", "").lower()
== "true"

View File

@@ -89,9 +89,6 @@ STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)

View File

@@ -211,6 +211,7 @@ class DocumentSource(str, Enum):
IMAP = "imap"
BITBUCKET = "bitbucket"
TESTRAIL = "testrail"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
@@ -618,4 +619,5 @@ project management, and collaboration tools into a single, customizable platform
DocumentSource.AIRTABLE: "airtable - database",
DocumentSource.HIGHSPOT: "highspot - CRM data",
DocumentSource.IMAP: "imap - email data",
DocumentSource.TESTRAIL: "testrail - test case management tool for QA processes",
}

View File

@@ -468,6 +468,9 @@ class BlobStorageConnector(LoadConnector, PollConnector):
link = onyx_metadata.link or link
primary_owners = onyx_metadata.primary_owners
secondary_owners = onyx_metadata.secondary_owners
source_type = onyx_metadata.source_type or DocumentSource(
self.bucket_type.value
)
sections: list[TextSection | ImageSection] = []
if extraction_result.text_content.strip():
@@ -489,7 +492,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
if sections
else [TextSection(link=link, text="")]
),
source=DocumentSource(self.bucket_type.value),
source=source_type,
semantic_identifier=file_display_name,
doc_updated_at=time_updated,
metadata=custom_tags,

View File

@@ -13,14 +13,17 @@ from dateutil.parser import parse
from dateutil.parser import ParserError
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import OnyxMetadata
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import is_valid_email
T = TypeVar("T")
U = TypeVar("U")
logger = setup_logger()
def datetime_to_utc(dt: datetime) -> datetime:
@@ -105,6 +108,28 @@ def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
def _parse_document_source(connector_type: Any) -> DocumentSource | None:
if connector_type is None:
return None
if isinstance(connector_type, DocumentSource):
return connector_type
if not isinstance(connector_type, str):
logger.warning(f"Invalid connector_type type: {type(connector_type).__name__}")
return None
normalized = re.sub(r"[\s\-]+", "_", connector_type.strip().lower())
try:
return DocumentSource(normalized)
except ValueError:
logger.warning(
f"Invalid connector_type value: '{connector_type}' "
f"(normalized: '{normalized}')"
)
return None
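
The normalization collapses whitespace and hyphens into underscores and lowercases the value before the enum lookup, so user-supplied connector_type metadata such as "Google Drive" or "google-drive" resolves to the same DocumentSource. A small illustration of the same regex; normalize_connector_type is a hypothetical helper for demonstration only.

import re

def normalize_connector_type(connector_type: str) -> str:
    """Mirror of the normalization used in _parse_document_source (illustrative)."""
    return re.sub(r"[\s\-]+", "_", connector_type.strip().lower())

assert normalize_connector_type("Google Drive") == "google_drive"
assert normalize_connector_type(" google-drive ") == "google_drive"
assert normalize_connector_type("TESTRAIL") == "testrail"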
def process_onyx_metadata(
metadata: dict[str, Any],
) -> tuple[OnyxMetadata, dict[str, Any]]:
@@ -125,13 +150,14 @@ def process_onyx_metadata(
if s_owner_names
else None
)
source_type = _parse_document_source(metadata.get("connector_type"))
dt_str = metadata.get("doc_updated_at")
doc_updated_at = time_str_to_utc(dt_str) if dt_str else None
return (
OnyxMetadata(
source_type=metadata.get("connector_type"),
source_type=source_type,
link=metadata.get("link"),
file_display_name=metadata.get("file_display_name"),
title=metadata.get("title"),

View File

@@ -105,10 +105,7 @@ def _process_file(
link = onyx_metadata.link
# These metadata items are not settable by the user
source_type_str = metadata.get("connector_type")
source_type = (
DocumentSource(source_type_str) if source_type_str else DocumentSource.FILE
)
source_type = onyx_metadata.source_type or DocumentSource.FILE
doc_id = f"FILE_CONNECTOR__{file_id}"
title = metadata.get("title") or file_display_name

View File

@@ -426,7 +426,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
title = ticket.properties.get("subject") or f"Ticket {ticket.id}"
link = self._get_object_url("tickets", ticket.id)
content_text = ticket.properties.get("content", "")
content_text = ticket.properties.get("content") or ""
# Main ticket section
sections = [TextSection(link=link, text=content_text)]

View File

@@ -200,6 +200,10 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.bitbucket.connector",
class_name="BitbucketConnector",
),
DocumentSource.TESTRAIL: ConnectorMapping(
module_path="onyx.connectors.testrail.connector",
class_name="TestRailConnector",
),
# just for integration tests
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
module_path="onyx.connectors.mock_connector.connector",

View File

@@ -0,0 +1 @@
# Package marker for TestRail connector

View File

@@ -0,0 +1,560 @@
from __future__ import annotations
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import ClassVar
from typing import Optional
import requests
from bs4 import BeautifulSoup
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import remove_markdown_image_references
logger = setup_logger()
class TestRailConnector(LoadConnector, PollConnector):
"""Connector for TestRail.
Minimal implementation that indexes Test Cases per project.
"""
document_source_type: ClassVar[DocumentSource] = DocumentSource.TESTRAIL
# Fields that need ID-to-label value mapping
FIELDS_NEEDING_VALUE_MAPPING: ClassVar[set[str]] = {
"priority_id",
"custom_automation_type",
"custom_scenario_db_automation",
"custom_case_golden_canvas_automation",
"custom_customers",
"custom_case_environments",
"custom_case_overall_automation",
"custom_case_team_ownership",
"custom_case_unit_or_integration_automation",
"custom_effort",
}
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
project_ids: str | list[int] | None = None,
cases_page_size: int | None = None,
max_pages: int | None = None,
skip_doc_absolute_chars: int | None = None,
) -> None:
self.base_url: str | None = None
self.username: str | None = None
self.api_key: str | None = None
self.batch_size = batch_size
parsed_project_ids: list[int] | None
# Parse project_ids from string if needed
# None = all projects (no filtering), [] = no projects, [1,2,3] = specific projects
if isinstance(project_ids, str):
if project_ids.strip():
parsed_project_ids = [
int(x.strip()) for x in project_ids.split(",") if x.strip()
]
else:
# Empty string from UI means "all projects"
parsed_project_ids = None
elif project_ids is None:
parsed_project_ids = None
else:
parsed_project_ids = [int(pid) for pid in project_ids]
self.project_ids: list[int] | None = parsed_project_ids
# Handle empty strings from UI and convert to int with defaults
self.cases_page_size = (
int(cases_page_size)
if cases_page_size and str(cases_page_size).strip()
else 250
)
self.max_pages = (
int(max_pages) if max_pages and str(max_pages).strip() else 10000
)
self.skip_doc_absolute_chars = (
int(skip_doc_absolute_chars)
if skip_doc_absolute_chars and str(skip_doc_absolute_chars).strip()
else 200000
)
# Cache for field labels and value mappings - will be populated on first use
self._field_labels: dict[str, str] | None = None
self._value_maps: dict[str, dict[str, str]] | None = None
# --- Rich text sanitization helpers ---
# Note: TestRail stores some fields as HTML (e.g. shared test steps).
# This function handles both HTML and plain text.
@staticmethod
def _sanitize_rich_text(value: Any) -> str:
if value is None:
return ""
text = str(value)
# Parse HTML and remove image tags
soup = BeautifulSoup(text, "html.parser")
# Remove all img tags and their containers
for img_tag in soup.find_all("img"):
img_tag.decompose()
for span in soup.find_all("span", class_="markdown-img-container"):
span.decompose()
# Use format_document_soup for better HTML-to-text conversion
# This preserves document structure (paragraphs, lists, line breaks, etc.)
text = format_document_soup(soup)
# Also remove markdown-style image references (in case any remain)
text = remove_markdown_image_references(text)
return text.strip()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Expected keys from UI credential JSON
self.base_url = str(credentials["testrail_base_url"]).rstrip("/")
self.username = str(credentials["testrail_username"]) # email or username
self.api_key = str(credentials["testrail_api_key"]) # API key (password)
return None
def validate_connector_settings(self) -> None:
"""Lightweight validation to surface common misconfigurations early."""
projects = self._list_projects()
if not projects:
logger.warning("TestRail: no projects visible to this credential.")
# ---- API helpers ----
def _api_get(self, endpoint: str, params: Optional[dict[str, Any]] = None) -> Any:
if not self.base_url or not self.username or not self.api_key:
raise ConnectorMissingCredentialError("testrail")
# TestRail API base is typically /index.php?/api/v2/<endpoint>
url = f"{self.base_url}/index.php?/api/v2/{endpoint}"
try:
response = requests.get(
url,
auth=(self.username, self.api_key),
params=params,
)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
status = e.response.status_code if getattr(e, "response", None) else None
if status == 401:
raise CredentialExpiredError(
"Invalid or expired TestRail credentials (HTTP 401)."
) from e
if status == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access TestRail resources (HTTP 403)."
) from e
raise UnexpectedValidationError(
f"Unexpected TestRail HTTP error (status={status})."
) from e
except requests.exceptions.RequestException as e:
raise UnexpectedValidationError(f"TestRail request failed: {e}") from e
try:
return response.json()
except ValueError as e:
raise UnexpectedValidationError(
"Invalid JSON returned by TestRail API"
) from e
def _list_projects(self) -> list[dict[str, Any]]:
projects = self._api_get("get_projects")
if isinstance(projects, dict):
projects_list = projects.get("projects")
return projects_list if isinstance(projects_list, list) else []
return []
def _list_suites(self, project_id: int) -> list[dict[str, Any]]:
"""Return suites for a project. If the project is in single-suite mode,
some TestRail instances may return an empty list; callers should
gracefully fallback to calling get_cases without suite_id.
"""
suites = self._api_get(f"get_suites/{project_id}")
if isinstance(suites, dict):
suites_list = suites.get("suites")
return suites_list if isinstance(suites_list, list) else []
return []
def _get_case_fields(self) -> list[dict[str, Any]]:
"""Get case field definitions from TestRail API."""
try:
fields = self._api_get("get_case_fields")
return fields if isinstance(fields, list) else []
except Exception as e:
logger.warning(f"Failed to fetch case fields from TestRail: {e}")
return []
def _parse_items_string(self, items_str: str) -> dict[str, str]:
"""Parse items string from field config into ID -> label mapping.
Format: "1, Option A\\n2, Option B\\n3, Option C"
Returns: {"1": "Option A", "2": "Option B", "3": "Option C"}
"""
id_to_label: dict[str, str] = {}
if not items_str:
return id_to_label
for line in items_str.split("\n"):
line = line.strip()
if not line:
continue
parts = line.split(",", 1)
if len(parts) == 2:
item_id = parts[0].strip()
item_label = parts[1].strip()
id_to_label[item_id] = item_label
return id_to_label
def _build_field_maps(self) -> tuple[dict[str, str], dict[str, dict[str, str]]]:
"""Build both field labels and value mappings in one pass.
Returns:
(field_labels, value_maps) where:
- field_labels: system_name -> label
- value_maps: system_name -> {id -> label}
"""
field_labels = {}
value_maps = {}
try:
fields = self._get_case_fields()
for field in fields:
system_name = field.get("system_name")
# Build field label map
label = field.get("label")
if system_name and label:
field_labels[system_name] = label
# Build value map if needed
if system_name in self.FIELDS_NEEDING_VALUE_MAPPING:
configs = field.get("configs", [])
if configs and len(configs) > 0:
options = configs[0].get("options", {})
items_str = options.get("items")
if items_str:
value_maps[system_name] = self._parse_items_string(
items_str
)
except Exception as e:
logger.warning(f"Failed to build field maps from TestRail: {e}")
return field_labels, value_maps
def _get_field_labels(self) -> dict[str, str]:
"""Get field labels, fetching from API if not cached."""
if self._field_labels is None:
self._field_labels, self._value_maps = self._build_field_maps()
return self._field_labels
def _get_value_maps(self) -> dict[str, dict[str, str]]:
"""Get value maps, fetching from API if not cached."""
if self._value_maps is None:
self._field_labels, self._value_maps = self._build_field_maps()
return self._value_maps
def _map_field_value(self, field_name: str, field_value: Any) -> str:
"""Map a field value using the value map if available.
Examples:
- priority_id: 2 -> "Medium"
- custom_case_team_ownership: [10] -> "Sim Platform"
- custom_case_environments: [1, 2] -> "Local, Cloud"
"""
if field_value is None or field_value == "":
return ""
# Get value map for this field
value_maps = self._get_value_maps()
value_map = value_maps.get(field_name, {})
# Handle list values
if isinstance(field_value, list):
if not field_value:
return ""
mapped = [value_map.get(str(v), str(v)) for v in field_value]
return ", ".join(mapped)
# Handle single values
val_str = str(field_value)
return value_map.get(val_str, val_str)
def _get_cases(
self, project_id: int, suite_id: Optional[int], limit: int, offset: int
) -> list[dict[str, Any]]:
"""Get cases for a project from the API."""
params: dict[str, Any] = {"limit": limit, "offset": offset}
if suite_id is not None:
params["suite_id"] = suite_id
cases_response = self._api_get(f"get_cases/{project_id}", params=params)
cases_list: list[dict[str, Any]] = []
if isinstance(cases_response, dict):
cases_items = cases_response.get("cases")
if isinstance(cases_items, list):
cases_list = cases_items
return cases_list
def _iter_cases(
self,
project_id: int,
suite_id: Optional[int] = None,
start: Optional[SecondsSinceUnixEpoch] = None,
end: Optional[SecondsSinceUnixEpoch] = None,
) -> Iterator[dict[str, Any]]:
# Pagination: TestRail supports 'limit' and 'offset' for many list endpoints
limit = self.cases_page_size
# Use a bounded page loop to avoid infinite loops on API anomalies
for page_index in range(self.max_pages):
offset = page_index * limit
cases = self._get_cases(project_id, suite_id, limit, offset)
if not cases:
break
# Filter by updated window if provided
for case in cases:
# 'updated_on' is unix timestamp (seconds)
updated_on = case.get("updated_on") or case.get("created_on")
if start is not None and updated_on is not None and updated_on < start:
continue
if end is not None and updated_on is not None and updated_on > end:
continue
yield case
if len(cases) < limit:
break
def _build_case_link(self, project_id: int, case_id: int) -> str:
# Standard UI link to a case
return f"{self.base_url}/index.php?/cases/view/{case_id}"
def _doc_from_case(
self,
project: dict[str, Any],
case: dict[str, Any],
suite: dict[str, Any] | None = None,
) -> Document | None:
project_id = project.get("id")
if not isinstance(project_id, int):
logger.warning(
"Skipping TestRail case because project id is missing or invalid: %s",
project_id,
)
return None
case_id = case.get("id")
if not isinstance(case_id, int):
logger.warning(
"Skipping TestRail case because case id is missing or invalid: %s",
case_id,
)
return None
title = case.get("title", f"Case {case_id}")
case_key = f"C{case_id}"
# Convert epoch seconds to aware datetime if available
updated = case.get("updated_on") or case.get("created_on")
updated_dt = (
datetime.fromtimestamp(updated, tz=timezone.utc)
if isinstance(updated, (int, float))
else None
)
text_lines: list[str] = []
if case.get("title"):
text_lines.append(f"Title: {case['title']}")
if case_key:
text_lines.append(f"Case ID: {case_key}")
if case_id is not None:
text_lines.append(f"ID: {case_id}")
doc_link = case.get("custom_documentation_link")
if doc_link:
text_lines.append(f"Documentation: {doc_link}")
# Add fields that need value mapping
field_labels = self._get_field_labels()
for field_name in self.FIELDS_NEEDING_VALUE_MAPPING:
field_value = case.get(field_name)
if field_value is not None and field_value != "" and field_value != []:
mapped_value = self._map_field_value(field_name, field_value)
if mapped_value:
# Get label from TestRail field definition
label = field_labels.get(
field_name, field_name.replace("_", " ").title()
)
text_lines.append(f"{label}: {mapped_value}")
pre = self._sanitize_rich_text(case.get("custom_preconds"))
if pre:
text_lines.append(f"Preconditions: {pre}")
# Steps: use separated steps format if available
steps_added = False
steps_separated = case.get("custom_steps_separated")
if isinstance(steps_separated, list) and steps_separated:
rendered_steps: list[str] = []
for idx, step_item in enumerate(steps_separated, start=1):
step_content = self._sanitize_rich_text(step_item.get("content"))
step_expected = self._sanitize_rich_text(step_item.get("expected"))
parts: list[str] = []
if step_content:
parts.append(f"Step {idx}: {step_content}")
else:
parts.append(f"Step {idx}:")
if step_expected:
parts.append(f"Expected: {step_expected}")
rendered_steps.append("\n".join(parts))
if rendered_steps:
text_lines.append("Steps:\n" + "\n".join(rendered_steps))
steps_added = True
# Fallback to custom_steps and custom_expected if no separated steps
if not steps_added:
custom_steps = self._sanitize_rich_text(case.get("custom_steps"))
custom_expected = self._sanitize_rich_text(case.get("custom_expected"))
if custom_steps:
text_lines.append(f"Steps: {custom_steps}")
if custom_expected:
text_lines.append(f"Expected: {custom_expected}")
link = self._build_case_link(project_id, case_id)
# Build full text and apply size policies
full_text = "\n".join(text_lines)
if len(full_text) > self.skip_doc_absolute_chars:
logger.warning(
f"Skipping TestRail case {case_id} due to excessive size: {len(full_text)} chars"
)
return None
# Metadata for document identification
metadata: dict[str, Any] = {}
if case_key:
metadata["case_key"] = case_key
# Include the human-friendly case key in identifiers for easier search
display_title = f"{case_key}: {title}" if case_key else title
return Document(
id=f"TESTRAIL_CASE_{case_id}",
source=DocumentSource.TESTRAIL,
semantic_identifier=display_title,
title=display_title,
sections=[TextSection(link=link, text=full_text)],
metadata=metadata,
doc_updated_at=updated_dt,
)
def _generate_documents(
self,
start: Optional[SecondsSinceUnixEpoch],
end: Optional[SecondsSinceUnixEpoch],
) -> GenerateDocumentsOutput:
if not self.base_url or not self.username or not self.api_key:
raise ConnectorMissingCredentialError("testrail")
doc_batch: list[Document] = []
projects = self._list_projects()
project_filter: list[int] | None = self.project_ids
for project in projects:
project_id_raw = project.get("id")
if not isinstance(project_id_raw, int):
logger.warning(
"Skipping TestRail project with invalid id: %s", project_id_raw
)
continue
project_id = project_id_raw
# None = index all, [] = index none, [1,2,3] = index only those
if project_filter is not None and project_id not in project_filter:
continue
suites = self._list_suites(project_id)
if suites:
for s in suites:
suite_id = s.get("id")
for case in self._iter_cases(project_id, suite_id, start, end):
doc = self._doc_from_case(project, case, s)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
else:
# single-suite mode fallback
for case in self._iter_cases(project_id, None, start, end):
doc = self._doc_from_case(project, case, None)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
# ---- Onyx interfaces ----
def load_from_state(self) -> GenerateDocumentsOutput:
return self._generate_documents(start=None, end=None)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
return self._generate_documents(start=start, end=end)
if __name__ == "__main__":
from onyx.configs.app_configs import (
TESTRAIL_API_KEY,
TESTRAIL_BASE_URL,
TESTRAIL_USERNAME,
)
connector = TestRailConnector()
connector.load_credentials(
{
"testrail_base_url": TESTRAIL_BASE_URL,
"testrail_username": TESTRAIL_USERNAME,
"testrail_api_key": TESTRAIL_API_KEY,
}
)
connector.validate_connector_settings()
# Probe a tiny batch from load
total = 0
for batch in connector.load_from_state():
print(f"Fetched batch: {len(batch)} docs")
total += len(batch)
if total >= 10:
break
print(f"Total fetched in test: {total}")

View File

@@ -15,10 +15,14 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
from onyx.context.search.federated.slack_search_utils import build_slack_queries
from onyx.context.search.federated.slack_search_utils import ChannelTypeString
from onyx.context.search.federated.slack_search_utils import get_channel_type
from onyx.context.search.federated.slack_search_utils import (
get_channel_type_for_missing_scope,
)
from onyx.context.search.federated.slack_search_utils import is_recency_query
from onyx.context.search.federated.slack_search_utils import should_include_message
from onyx.context.search.models import InferenceChunk
@@ -46,7 +50,6 @@ logger = setup_logger()
HIGHLIGHT_START_CHAR = "\ue000"
HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_TYPES = ["public_channel", "im", "mpim", "private_channel"]
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
@@ -98,10 +101,12 @@ def fetch_and_cache_channel_metadata(
# Retry logic with exponential backoff
last_exception = None
available_channel_types = ALL_CHANNEL_TYPES.copy()
for attempt in range(CHANNEL_METADATA_MAX_RETRIES):
try:
# ALWAYS fetch all channel types including private
channel_types = ",".join(CHANNEL_TYPES)
# Use available channel types (may be reduced if scopes are missing)
channel_types = ",".join(available_channel_types)
# Fetch all channels in one call
cursor = None
@@ -157,6 +162,42 @@ def fetch_and_cache_channel_metadata(
except SlackApiError as e:
last_exception = e
# Extract all needed fields from response upfront
if e.response:
error_response = e.response.get("error", "")
needed_scope = e.response.get("needed", "")
else:
error_response = ""
needed_scope = ""
# Check if this is a missing_scope error
if error_response == "missing_scope":
# Get the channel type that requires this scope
missing_channel_type = get_channel_type_for_missing_scope(needed_scope)
if (
missing_channel_type
and missing_channel_type in available_channel_types
):
# Remove the problematic channel type and retry
available_channel_types.remove(missing_channel_type)
logger.warning(
f"Missing scope '{needed_scope}' for channel type '{missing_channel_type}'. "
f"Continuing with reduced channel types: {available_channel_types}"
)
# Retry immediately with the reduced channel type list (skips the backoff delay)
if available_channel_types: # Only continue if we have types left
continue
# Otherwise fall through to retry logic
else:
logger.error(
f"Missing scope '{needed_scope}' but could not map to channel type or already removed. "
f"Response: {e.response}"
)
# For other errors, use retry logic
if attempt < CHANNEL_METADATA_MAX_RETRIES - 1:
retry_delay = CHANNEL_METADATA_RETRY_DELAY * (2**attempt)
logger.warning(
@@ -169,7 +210,15 @@ def fetch_and_cache_channel_metadata(
f"Failed to fetch channel metadata after {CHANNEL_METADATA_MAX_RETRIES} attempts: {e}"
)
# If we exhausted all retries, raise the last exception
# If we have some channel metadata despite errors, return it with a warning
if channel_metadata:
logger.warning(
f"Returning partial channel metadata ({len(channel_metadata)} channels) despite errors. "
f"Last error: {last_exception}"
)
return channel_metadata
# If we exhausted all retries and have no data, raise the last exception
if last_exception:
raise SlackApiError(
f"Channel metadata fetching failed after {CHANNEL_METADATA_MAX_RETRIES} attempts",

View File

@@ -29,6 +29,9 @@ DAYS_PER_WEEK = 7
DAYS_PER_MONTH = 30
MAX_CONTENT_WORDS = 3
# Punctuation to strip from words during analysis
WORD_PUNCTUATION = ".,!?;:\"'#"
RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]
@@ -41,6 +44,48 @@ class ChannelTypeString(str, Enum):
PUBLIC_CHANNEL = "public_channel"
# All Slack channel types for fetching metadata
ALL_CHANNEL_TYPES = [
ChannelTypeString.PUBLIC_CHANNEL.value,
ChannelTypeString.IM.value,
ChannelTypeString.MPIM.value,
ChannelTypeString.PRIVATE_CHANNEL.value,
]
# Map Slack API scopes to their corresponding channel types
# This is used for graceful degradation when scopes are missing
SCOPE_TO_CHANNEL_TYPE_MAP = {
"mpim:read": ChannelTypeString.MPIM.value,
"mpim:history": ChannelTypeString.MPIM.value,
"im:read": ChannelTypeString.IM.value,
"im:history": ChannelTypeString.IM.value,
"groups:read": ChannelTypeString.PRIVATE_CHANNEL.value,
"groups:history": ChannelTypeString.PRIVATE_CHANNEL.value,
"channels:read": ChannelTypeString.PUBLIC_CHANNEL.value,
"channels:history": ChannelTypeString.PUBLIC_CHANNEL.value,
}
def get_channel_type_for_missing_scope(scope: str) -> str | None:
"""Get the channel type that requires a specific Slack scope.
Args:
scope: The Slack API scope (e.g., 'mpim:read', 'im:history')
Returns:
The channel type string if scope is recognized, None otherwise
Examples:
>>> get_channel_type_for_missing_scope('mpim:read')
'mpim'
>>> get_channel_type_for_missing_scope('im:read')
'im'
>>> get_channel_type_for_missing_scope('unknown:scope')
None
"""
return SCOPE_TO_CHANNEL_TYPE_MAP.get(scope)
def _parse_llm_code_block_response(response: str) -> str:
"""Remove code block markers from LLM response if present.
@@ -64,11 +109,40 @@ def _parse_llm_code_block_response(response: str) -> str:
def is_recency_query(query: str) -> bool:
return any(
"""Check if a query is primarily about recency (not content + recency).
Returns True only for pure recency queries like "recent messages" or "latest updates",
but False for queries with content + recency like "golf scores last saturday".
"""
# Check if query contains recency keywords
has_recency_keyword = any(
re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE)
for keyword in RECENCY_KEYWORDS
)
if not has_recency_keyword:
return False
# Get combined stop words (NLTK + Slack-specific)
all_stop_words = _get_combined_stop_words()
# Extract content words (excluding stop words)
query_lower = query.lower()
words = query_lower.split()
# Count content words (not stop words, length > 2)
content_word_count = 0
for word in words:
clean_word = word.strip(WORD_PUNCTUATION)
if clean_word and len(clean_word) > 2 and clean_word not in all_stop_words:
content_word_count += 1
# If query has significant content words (>= 2), it's not a pure recency query
# Examples:
# - "recent messages" -> content_word_count = 0 -> pure recency
# - "golf scores last saturday" -> content_word_count = 3 (golf, scores, saturday) -> not pure recency
return content_word_count < 2
def extract_date_range_from_query(
query: str,
@@ -83,6 +157,21 @@ def extract_date_range_from_query(
if re.search(r"\byesterday\b", query_lower):
return min(1, default_search_days)
# Handle "last [day of week]" - e.g., "last monday", "last saturday"
days_of_week = [
"monday",
"tuesday",
"wednesday",
"thursday",
"friday",
"saturday",
"sunday",
]
for day in days_of_week:
if re.search(rf"\b(?:last|this)\s+{day}\b", query_lower):
# Assume last occurrence of that day was within the past week
return min(DAYS_PER_WEEK, default_search_days)
match = re.search(r"\b(?:last|past)\s+(\d+)\s+days?\b", query_lower)
if match:
days = int(match.group(1))
@@ -120,22 +209,40 @@ def extract_date_range_from_query(
try:
data = json.loads(response_clean)
if not isinstance(data, dict):
logger.debug(
f"LLM date extraction returned non-dict response for query: "
f"'{query}', using default: {default_search_days} days"
)
return default_search_days
days_back = data.get("days_back")
if days_back is None:
logger.debug(
f"LLM date extraction returned null for query: '{query}', using default: {default_search_days} days"
f"LLM date extraction returned null for query: '{query}', "
f"using default: {default_search_days} days"
)
return default_search_days
if not isinstance(days_back, (int, float)):
logger.debug(
f"LLM date extraction returned non-numeric days_back for "
f"query: '{query}', using default: {default_search_days} days"
)
return default_search_days
except json.JSONDecodeError:
logger.debug(
f"Failed to parse LLM date extraction response for query: '{query}', using default: {default_search_days} days"
f"Failed to parse LLM date extraction response for query: '{query}' "
f"(response: '{response_clean}'), "
f"using default: {default_search_days} days"
)
return default_search_days
return min(days_back, default_search_days)
return min(int(days_back), default_search_days)
except Exception as e:
logger.warning(f"Error extracting date range with LLM: {e}")
logger.warning(f"Error extracting date range with LLM for query '{query}': {e}")
return default_search_days
@@ -413,6 +520,29 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
)
def _get_combined_stop_words() -> set[str]:
"""Get combined NLTK + Slack-specific stop words.
Returns a set of stop words for filtering content words.
Falls back to just Slack-specific stop words if NLTK is unavailable.
Note: Currently only supports English stop words. Non-English queries
may have suboptimal content word extraction. Future enhancement could
detect query language and load appropriate stop words.
"""
try:
from nltk.corpus import stopwords # type: ignore
# TODO: Support multiple languages - currently hardcoded to English
# Could detect language or allow configuration
nltk_stop_words = set(stopwords.words("english"))
except Exception:
# Fallback if NLTK not available
nltk_stop_words = set()
return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
def extract_content_words_from_recency_query(
query_text: str, channel_references: set[str]
) -> list[str]:
@@ -427,28 +557,19 @@ def extract_content_words_from_recency_query(
Returns:
List of content words (up to MAX_CONTENT_WORDS)
"""
# Get standard English stop words from NLTK (lazy import)
try:
from nltk.corpus import stopwords # type: ignore
nltk_stop_words = set(stopwords.words("english"))
except Exception:
# Fallback if NLTK not available
nltk_stop_words = set()
# Combine NLTK stop words with Slack-specific stop words
all_stop_words = nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
# Get combined stop words (NLTK + Slack-specific)
all_stop_words = _get_combined_stop_words()
words = query_text.split()
content_words = []
for word in words:
clean_word = word.lower().strip(".,!?;:\"'#")
clean_word = word.lower().strip(WORD_PUNCTUATION)
# Skip if it's a channel reference or a stop word
if clean_word in channel_references:
continue
if clean_word and clean_word not in all_stop_words and len(clean_word) > 2:
clean_word_orig = word.strip(".,!?;:\"'#")
clean_word_orig = word.strip(WORD_PUNCTUATION)
if clean_word_orig.lower() not in all_stop_words:
content_words.append(clean_word_orig)

View File

@@ -442,10 +442,25 @@ def set_cc_pair_repeated_error_state(
cc_pair_id: int,
in_repeated_error_state: bool,
) -> None:
values: dict = {"in_repeated_error_state": in_repeated_error_state}
# When entering repeated error state, also pause the connector
# to prevent continued indexing retry attempts.
# However, don't pause if there's an active manual indexing trigger,
# which indicates the user wants to retry immediately.
if in_repeated_error_state:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# Only pause if there's no manual indexing trigger active
if cc_pair and cc_pair.indexing_trigger is None:
values["status"] = ConnectorCredentialPairStatus.PAUSED
stmt = (
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.values(in_repeated_error_state=in_repeated_error_state)
.values(**values)
)
db_session.execute(stmt)
db_session.commit()

View File

@@ -2482,6 +2482,50 @@ class CloudEmbeddingProvider(Base):
return f"<EmbeddingProvider(type='{self.provider_type}')>"
class InternetSearchProvider(Base):
__tablename__ = "internet_search_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self) -> str:
return f"<InternetSearchProvider(name='{self.name}', provider_type='{self.provider_type}')>"
class InternetContentProvider(Base):
__tablename__ = "internet_content_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
def __repr__(self) -> str:
return f"<InternetContentProvider(name='{self.name}', provider_type='{self.provider_type}')>"
class DocumentSet(Base):
__tablename__ = "document_set"

View File

@@ -0,0 +1,309 @@
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import InternetContentProvider
from onyx.db.models import InternetSearchProvider
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
def fetch_web_search_providers(db_session: Session) -> list[InternetSearchProvider]:
stmt = select(InternetSearchProvider).order_by(InternetSearchProvider.id.asc())
return list(db_session.scalars(stmt).all())
def fetch_web_content_providers(db_session: Session) -> list[InternetContentProvider]:
stmt = select(InternetContentProvider).order_by(InternetContentProvider.id.asc())
return list(db_session.scalars(stmt).all())
def fetch_active_web_search_provider(
db_session: Session,
) -> InternetSearchProvider | None:
stmt = select(InternetSearchProvider).where(
InternetSearchProvider.is_active.is_(True)
)
return db_session.scalars(stmt).first()
def fetch_web_search_provider_by_id(
provider_id: int, db_session: Session
) -> InternetSearchProvider | None:
return db_session.get(InternetSearchProvider, provider_id)
def fetch_web_search_provider_by_name(
name: str, db_session: Session
) -> InternetSearchProvider | None:
stmt = select(InternetSearchProvider).where(InternetSearchProvider.name.ilike(name))
return db_session.scalars(stmt).first()
def _ensure_unique_search_name(
name: str, provider_id: int | None, db_session: Session
) -> None:
existing = fetch_web_search_provider_by_name(name=name, db_session=db_session)
if existing and existing.id != provider_id:
raise ValueError(f"A web search provider named '{name}' already exists.")
def _apply_search_provider_updates(
provider: InternetSearchProvider,
*,
name: str,
provider_type: WebSearchProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
) -> None:
provider.name = name
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
def upsert_web_search_provider(
*,
provider_id: int | None,
name: str,
provider_type: WebSearchProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
activate: bool,
db_session: Session,
) -> InternetSearchProvider:
_ensure_unique_search_name(
name=name, provider_id=provider_id, db_session=db_session
)
provider: InternetSearchProvider | None = None
if provider_id is not None:
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
else:
provider = InternetSearchProvider()
db_session.add(provider)
_apply_search_provider_updates(
provider,
name=name,
provider_type=provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
config=config,
)
db_session.flush()
if activate:
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider
def set_active_web_search_provider(
*, provider_id: int | None, db_session: Session
) -> InternetSearchProvider:
if provider_id is None:
raise ValueError("Cannot activate a provider without an id.")
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
db_session.execute(
update(InternetSearchProvider)
.where(
InternetSearchProvider.is_active.is_(True),
InternetSearchProvider.id != provider_id,
)
.values(is_active=False)
)
provider.is_active = True
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_web_search_provider(
*, provider_id: int | None, db_session: Session
) -> InternetSearchProvider:
if provider_id is None:
raise ValueError("Cannot deactivate a provider without an id.")
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
provider.is_active = False
db_session.flush()
db_session.refresh(provider)
return provider
def delete_web_search_provider(provider_id: int, db_session: Session) -> None:
provider = fetch_web_search_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web search provider with id {provider_id} exists.")
db_session.delete(provider)
db_session.flush()
db_session.commit()
# Content provider helpers
def fetch_active_web_content_provider(
db_session: Session,
) -> InternetContentProvider | None:
stmt = select(InternetContentProvider).where(
InternetContentProvider.is_active.is_(True)
)
return db_session.scalars(stmt).first()
def fetch_web_content_provider_by_id(
provider_id: int, db_session: Session
) -> InternetContentProvider | None:
return db_session.get(InternetContentProvider, provider_id)
def fetch_web_content_provider_by_name(
name: str, db_session: Session
) -> InternetContentProvider | None:
stmt = select(InternetContentProvider).where(
InternetContentProvider.name.ilike(name)
)
return db_session.scalars(stmt).first()
def _ensure_unique_content_name(
name: str, provider_id: int | None, db_session: Session
) -> None:
existing = fetch_web_content_provider_by_name(name=name, db_session=db_session)
if existing and existing.id != provider_id:
raise ValueError(f"A web content provider named '{name}' already exists.")
def _apply_content_provider_updates(
provider: InternetContentProvider,
*,
name: str,
provider_type: WebContentProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
) -> None:
provider.name = name
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
def upsert_web_content_provider(
*,
provider_id: int | None,
name: str,
provider_type: WebContentProviderType,
api_key: str | None,
api_key_changed: bool,
config: dict[str, str] | None,
activate: bool,
db_session: Session,
) -> InternetContentProvider:
_ensure_unique_content_name(
name=name, provider_id=provider_id, db_session=db_session
)
provider: InternetContentProvider | None = None
if provider_id is not None:
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
else:
provider = InternetContentProvider()
db_session.add(provider)
_apply_content_provider_updates(
provider,
name=name,
provider_type=provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
config=config,
)
db_session.flush()
if activate:
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider
def set_active_web_content_provider(
*, provider_id: int | None, db_session: Session
) -> InternetContentProvider:
if provider_id is None:
raise ValueError("Cannot activate a provider without an id.")
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
db_session.execute(
update(InternetContentProvider)
.where(
InternetContentProvider.is_active.is_(True),
InternetContentProvider.id != provider_id,
)
.values(is_active=False)
)
provider.is_active = True
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_web_content_provider(
*, provider_id: int | None, db_session: Session
) -> InternetContentProvider:
if provider_id is None:
raise ValueError("Cannot deactivate a provider without an id.")
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
provider.is_active = False
db_session.flush()
db_session.refresh(provider)
return provider
def delete_web_content_provider(provider_id: int, db_session: Session) -> None:
provider = fetch_web_content_provider_by_id(provider_id, db_session)
if provider is None:
raise ValueError(f"No web content provider with id {provider_id} exists.")
db_session.delete(provider)
db_session.flush()
db_session.commit()
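A hedged usage sketch of the helpers above (not part of the diff): create a provider and make it the single active one in one flow. The import paths match others in this change set; the EXA enum member and the key value are illustrative assumptions.

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.db.web_search import upsert_web_search_provider
from shared_configs.enums import WebSearchProviderType

with get_session_with_current_tenant() as db_session:
    provider = upsert_web_search_provider(
        provider_id=None,                            # None -> create a new row
        name="Exa",
        provider_type=WebSearchProviderType.EXA,     # assumed enum member
        api_key="exa-...",                           # placeholder key
        api_key_changed=True,                        # persist the provided key
        config=None,
        activate=True,                               # deactivates any other active provider
        db_session=db_session,
    )
    active = fetch_active_web_search_provider(db_session)
    assert active is not None and active.id == provider.id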

View File

@@ -510,6 +510,7 @@ class LitellmLLM(LLM):
# model params
temperature=(1 if is_reasoning else self._temperature),
timeout=timeout_override or self._timeout,
**({"stream_options": {"include_usage": True}} if stream else {}),
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
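A hedged sketch of what the new stream_options flag buys at the LiteLLM call site: with include_usage set, the final streamed chunk carries token usage rather than dropping it. The model name and prompt are placeholders.

import litellm

stream = litellm.completion(
    model="gpt-4o-mini",                             # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},          # request usage on the final chunk
)
for chunk in stream:
    usage = getattr(chunk, "usage", None)
    if usage:
        print(usage)                                 # prompt/completion/total token counts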

View File

@@ -3,29 +3,87 @@ import time
import uuid
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union
from litellm import AllMessageValues
from litellm.completion_extras.litellm_responses_transformation.transformation import (
OpenAiResponsesToChatCompletionStreamIterator,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
try:
from litellm.litellm_core_utils.prompt_templates.common_utils import (
extract_images_from_message,
)
except ImportError:
extract_images_from_message = None # type: ignore[assignment]
from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator
from litellm.llms.ollama.chat.transformation import OllamaChatConfig
from litellm.llms.ollama.common_utils import OllamaError
from litellm.types.llms.ollama import OllamaChatCompletionMessage
try:
from litellm.types.llms.ollama import OllamaChatCompletionMessage
except ImportError:
class OllamaChatCompletionMessage(TypedDict, total=False): # type: ignore[no-redef]
"""Fallback for LiteLLM versions where this TypedDict was removed."""
role: str
content: Optional[str]
images: Optional[List[Any]]
thinking: Optional[str]
tool_calls: Optional[List["OllamaToolCall"]]
from litellm.types.llms.ollama import OllamaToolCall
from litellm.types.llms.ollama import OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import ChatCompletionUsageBlock
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import ModelResponseStream
from litellm.utils import verbose_logger
from pydantic import BaseModel
if extract_images_from_message is None:
def extract_images_from_message(
message: AllMessageValues,
) -> Optional[List[Any]]:
"""Fallback for LiteLLM versions that dropped extract_images_from_message."""
images: List[Any] = []
content = message.get("content")
if not isinstance(content, list):
return None
for item in content:
if not isinstance(item, Dict):
continue
item_type = item.get("type")
if item_type == "image_url":
image_url = item.get("image_url")
if isinstance(image_url, dict):
if image_url.get("url"):
images.append(image_url)
elif image_url:
images.append(image_url)
elif item_type in {"input_image", "image"}:
image_value = item.get("image")
if image_value:
images.append(image_value)
return images or None
def _patch_ollama_transform_request() -> None:
"""
Patches OllamaChatConfig.transform_request to handle reasoning content
@@ -254,16 +312,189 @@ def _patch_ollama_chunk_parser() -> None:
OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser # type: ignore[method-assign]
def _patch_openai_responses_chunk_parser() -> None:
"""
Patches OpenAiResponsesToChatCompletionStreamIterator.chunk_parser to properly
handle OpenAI Responses API streaming format and convert it to chat completion format.
"""
if (
getattr(
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser,
"__name__",
"",
)
== "_patched_openai_responses_chunk_parser"
):
return
def _patched_openai_responses_chunk_parser(
self: Any, chunk: dict
) -> Union["GenericStreamingChunk", "ModelResponseStream"]:
# Transform responses API streaming chunk to chat completion format
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import (
ChatCompletionToolCallChunk,
GenericStreamingChunk,
)
verbose_logger.debug(
f"Chat provider: transform_streaming_response called with chunk: {chunk}"
)
parsed_chunk = chunk
if not parsed_chunk:
raise ValueError("Chat provider: Empty parsed_chunk")
if not isinstance(parsed_chunk, dict):
raise ValueError(f"Chat provider: Invalid chunk type {type(parsed_chunk)}")
# Handle different event types from responses API
event_type = parsed_chunk.get("type")
verbose_logger.debug(f"Chat provider: Processing event type: {event_type}")
if event_type == "response.created":
# Initial response creation event
verbose_logger.debug(f"Chat provider: response.created -> {chunk}")
return GenericStreamingChunk(
text="", tool_use=None, is_finished=False, finish_reason="", usage=None
)
elif event_type == "response.output_item.added":
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=output_item.get("call_id"),
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=output_item.get("name", None),
arguments=parsed_chunk.get("arguments", ""),
),
),
is_finished=False,
finish_reason="",
usage=None,
)
elif output_item.get("type") == "message":
pass
elif output_item.get("type") == "reasoning":
pass
else:
raise ValueError(f"Chat provider: Invalid output_item {output_item}")
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=None,
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=None, arguments=content_part
),
),
is_finished=False,
finish_reason="",
usage=None,
)
else:
raise ValueError(
f"Chat provider: Invalid function argument delta {parsed_chunk}"
)
elif event_type == "response.output_item.done":
# Output item finished streaming
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
id=output_item.get("call_id"),
index=0,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=parsed_chunk.get("name", None),
arguments="", # responses API sends everything again, we don't
),
),
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
elif output_item.get("type") == "message":
return GenericStreamingChunk(
finish_reason="stop", is_finished=True, usage=None, text=""
)
elif output_item.get("type") == "reasoning":
pass
else:
raise ValueError(f"Chat provider: Invalid output_item {output_item}")
elif event_type == "response.output_text.delta":
# Text delta for an output content part
content_part = parsed_chunk.get("delta", None)
if content_part is not None:
return GenericStreamingChunk(
text=content_part,
tool_use=None,
is_finished=False,
finish_reason="",
usage=None,
)
else:
raise ValueError(f"Chat provider: Invalid text delta {parsed_chunk}")
elif event_type == "response.reasoning_summary_text.delta":
content_part = parsed_chunk.get("delta", None)
if content_part:
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=cast(int, parsed_chunk.get("summary_index")),
delta=Delta(reasoning_content=content_part),
)
]
)
else:
pass
# For any unhandled event type, fall through and return a minimal valid chunk
verbose_logger.debug(
f"Chat provider: Unhandled event type '{event_type}', creating empty chunk"
)
# Return a minimal valid chunk for unknown events
return GenericStreamingChunk(
text="", tool_use=None, is_finished=False, finish_reason="", usage=None
)
_patched_openai_responses_chunk_parser.__name__ = (
"_patched_openai_responses_chunk_parser"
)
OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_openai_responses_chunk_parser # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for Ollama compatibility.
Apply all necessary monkey patches to LiteLLM for compatibility.
This includes:
- Patching OllamaChatConfig.transform_request for reasoning content support
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -1,146 +0,0 @@
from collections.abc import Sequence
from langchain.schema.messages import BaseMessage
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ContentPart
from onyx.llm.message_types import ImageContentPart
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import TextContentPart
from onyx.llm.message_types import UserMessageWithParts
from onyx.llm.message_types import UserMessageWithText
def base_messages_to_chat_completion_msgs(
msgs: Sequence[BaseMessage],
) -> list[ChatCompletionMessage]:
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
def _base_message_to_chat_completion_msg(msg: BaseMessage) -> ChatCompletionMessage:
message_type_to_role = {
"human": "user",
"system": "system",
"ai": "assistant",
"tool": "tool",
}
role = message_type_to_role[msg.type]
content = msg.content
if isinstance(content, str):
# Simple string content
if role == "system":
system_msg: SystemMessage = {
"role": "system",
"content": content,
}
return system_msg
elif role == "user":
user_msg: UserMessageWithText = {
"role": "user",
"content": content,
}
return user_msg
else: # assistant
assistant_msg: AssistantMessage = {
"role": "assistant",
"content": content,
}
return assistant_msg
elif isinstance(content, list):
# List content - need to convert to OpenAI format
if role == "assistant":
# For assistant, convert list to simple string
# (OpenAI format uses string content, not list)
text_parts = []
for item in content:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
text_parts.append(item.get("text", ""))
else:
raise ValueError(
f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
)
assistant_msg_from_list: AssistantMessage = {
"role": "assistant",
"content": " ".join(text_parts) if text_parts else None,
}
return assistant_msg_from_list
else: # system or user
content_parts: list[ContentPart] = []
has_images = False
for item in content:
if isinstance(item, str):
content_parts.append(TextContentPart(type="text", text=item))
elif isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
content_parts.append(
TextContentPart(type="text", text=item.get("text", ""))
)
elif item_type == "image_url":
has_images = True
# Convert image_url to OpenAI format
image_url = item.get("image_url", {})
if isinstance(image_url, dict):
url = image_url.get("url", "")
detail = image_url.get("detail", "auto")
else:
url = image_url
detail = "auto"
image_part: ImageContentPart = {
"type": "image_url",
"image_url": {"url": url, "detail": detail},
}
content_parts.append(image_part)
else:
raise ValueError(f"Unexpected item type: {item_type}")
else:
raise ValueError(
f"Unexpected item type: {type(item)}. Item: {item}"
)
if role == "system":
# System messages should be text only, concatenate all text parts
text_parts = [
part["text"] for part in content_parts if part["type"] == "text"
]
system_msg_from_list: SystemMessage = {
"role": "system",
"content": " ".join(text_parts),
}
return system_msg_from_list
else: # user
# If there are images, use the parts format; otherwise use simple string
if has_images or len(content_parts) > 1:
user_msg_with_parts: UserMessageWithParts = {
"role": "user",
"content": content_parts,
}
return user_msg_with_parts
elif len(content_parts) == 1 and content_parts[0]["type"] == "text":
# Single text part - use simple string format
user_msg_simple: UserMessageWithText = {
"role": "user",
"content": content_parts[0]["text"],
}
return user_msg_simple
else:
# Empty content
user_msg_empty: UserMessageWithText = {
"role": "user",
"content": "",
}
return user_msg_empty
else:
raise ValueError(
f"Unexpected content type: {type(content)}. Content: {content}"
)

View File

@@ -38,10 +38,19 @@ class StreamingChoice(BaseModel):
delta: Delta = Field(default_factory=Delta)
class Usage(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
cache_creation_input_tokens: int
cache_read_input_tokens: int
class ModelResponseStream(BaseModel):
id: str
created: str
choice: StreamingChoice
usage: Usage | None = None
if TYPE_CHECKING:
@@ -174,10 +183,24 @@ def from_litellm_model_response_stream(
delta=parsed_delta,
)
usage_data = response_data.get("usage")
return ModelResponseStream(
id=response_id,
created=created,
choice=streaming_choice,
usage=(
Usage(
completion_tokens=usage_data.get("completion_tokens", 0),
prompt_tokens=usage_data.get("prompt_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
cache_creation_input_tokens=usage_data.get(
"cache_creation_input_tokens", 0
),
cache_read_input_tokens=usage_data.get("cache_read_input_tokens", 0),
)
if usage_data
else None
),
)
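A small hedged sketch of the defaulting above: a provider that omits the Anthropic-style cache counters still yields a valid Usage, with those fields falling back to zero (the sample numbers are made up; Usage is the model defined earlier in this file).

sample_usage_data = {"completion_tokens": 21, "prompt_tokens": 104, "total_tokens": 125}
sample_usage = Usage(
    completion_tokens=sample_usage_data.get("completion_tokens", 0),
    prompt_tokens=sample_usage_data.get("prompt_tokens", 0),
    total_tokens=sample_usage_data.get("total_tokens", 0),
    cache_creation_input_tokens=sample_usage_data.get("cache_creation_input_tokens", 0),
    cache_read_input_tokens=sample_usage_data.get("cache_read_input_tokens", 0),
)
assert sample_usage.cache_read_input_tokens == 0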

View File

@@ -116,6 +116,7 @@ def litellm_exception_to_error_msg(
from litellm.exceptions import Timeout
from litellm.exceptions import ContentPolicyViolationError
from litellm.exceptions import BudgetExceededError
from litellm.exceptions import ServiceUnavailableError
core_exception = _unwrap_nested_exception(e)
error_msg = str(core_exception)
@@ -168,6 +169,23 @@ def litellm_exception_to_error_msg(
if upstream_detail
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
)
elif isinstance(core_exception, ServiceUnavailableError):
provider_name = (
llm.config.model_provider
if llm is not None and llm.config.model_provider
else "The LLM provider"
)
# Check if this is specifically the Bedrock "Too many connections" error
if "Too many connections" in error_msg or "BedrockException" in error_msg:
error_msg = (
f"{provider_name} is experiencing high connection volume and cannot process your request right now. "
"This typically happens when there are too many simultaneous requests to the AI model. "
"Please wait a moment and try again. If this persists, contact your system administrator "
"to review connection limits and retry configurations."
)
else:
# Generic 503 Service Unavailable
error_msg = f"{provider_name} service error: {str(core_exception)}"
elif isinstance(core_exception, ContextWindowExceededError):
error_msg = (
"Context window exceeded: Your input is too long for the model to process."

View File

@@ -1,6 +1,7 @@
import logging
import sys
import traceback
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
@@ -87,6 +88,7 @@ from onyx.server.features.projects.api import router as projects_router
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.gpts.api import router as gpts_router
from onyx.server.kg.api import admin_router as kg_admin_router
@@ -102,6 +104,9 @@ from onyx.server.manage.llm.api import basic_router as llm_router
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -143,6 +148,13 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
warnings.filterwarnings(
"ignore", category=ResourceWarning, message=r"Unclosed client session"
)
warnings.filterwarnings(
"ignore", category=ResourceWarning, message=r"Unclosed connector"
)
logger = setup_logger()
file_handlers = [
@@ -392,6 +404,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)

View File

@@ -14,6 +14,7 @@ from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
@@ -28,6 +29,9 @@ from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import persona_has_search_tool
from onyx.db.users import get_user_by_email
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.handlers.utils import slackify_message_thread
@@ -155,6 +159,47 @@ def handle_regular_answer(
history_messages = messages[:-1]
single_message_history = slackify_message_thread(history_messages) or None
# Convert ThreadMessage objects to PreviousMessage for query rephrasing
thread_previous_messages: list[PreviousMessage] | None = None
if history_messages:
llm, _ = get_llms_for_persona(persona, user)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Work backwards from the most recent messages, keeping only what fits within the token budget
temp_messages = []
total_token_count = 0
for thread_msg in reversed(history_messages):
token_count = len(llm_tokenizer.encode(thread_msg.message))
# Stop if adding this message would exceed the max token count
if total_token_count + token_count > GEN_AI_HISTORY_CUTOFF:
break
temp_messages.append(
PreviousMessage(
message=thread_msg.message,
token_count=token_count,
message_type=thread_msg.role,
files=[],
tool_call=None,
refined_answer_improvement=None,
research_answer_purpose=None,
)
)
total_token_count += token_count
# Reverse back to chronological order (oldest to newest)
thread_previous_messages = list(reversed(temp_messages))
logger.info(
f"Converted {len(thread_previous_messages)} of {len(history_messages)} "
f"thread messages ({total_token_count} tokens) for query rephrasing"
)
# Always check for ACL permissions, including for document sets that were explicitly added
# to the Bot by the Administrator. (Change relative to earlier behavior where all documents
# in an attached document set were available to all users in the channel.)
@@ -184,6 +229,7 @@ def handle_regular_answer(
db_session=db_session,
bypass_acl=bypass_acl,
single_message_history=single_message_history,
thread_message_history=thread_previous_messages,
)
answer = gather_stream(packets)
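The thread-context block above budgets tokens from the newest message backwards; a minimal standalone sketch of that trimming strategy, with count_tokens standing in for llm_tokenizer.encode(...) (all names here are illustrative).

from collections.abc import Callable

def trim_history(
    messages: list[str], max_tokens: int, count_tokens: Callable[[str], int]
) -> list[str]:
    kept: list[str] = []
    total = 0
    for msg in reversed(messages):       # newest first
        tokens = count_tokens(msg)
        if total + tokens > max_tokens:
            break                        # budget exhausted; drop this and anything older
        kept.append(msg)
        total += tokens
    return list(reversed(kept))          # restore oldest-to-newest order

history = ["hello", "how are you today", "what is the weather in SF right now"]
print(trim_history(history, max_tokens=12, count_tokens=lambda m: len(m.split())))
# -> keeps only the most recent messages that fit the 12-token budget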

View File

@@ -72,6 +72,7 @@ from onyx.db.connector import create_connector
from onyx.db.connector import delete_connector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector import fetch_connectors
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.connector import get_connector_credential_ids
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
@@ -128,6 +129,7 @@ from onyx.server.documents.models import GmailCallback
from onyx.server.documents.models import GoogleAppCredentials
from onyx.server.documents.models import GoogleServiceAccountCredentialRequest
from onyx.server.documents.models import GoogleServiceAccountKey
from onyx.server.documents.models import IndexedSourceTypesResponse
from onyx.server.documents.models import IndexingStatusRequest
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
@@ -1479,6 +1481,17 @@ def get_connectors(
]
@router.get("/indexed-source-types")
def get_indexed_source_types(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> IndexedSourceTypesResponse:
source_types = sorted(
fetch_unique_document_sources(db_session), key=lambda source: source.value
)
return IndexedSourceTypesResponse(source_types=source_types)
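# Hedged illustration (not in the original file): with, say, Confluence and Web documents
# indexed, the JSON response looks like {"source_types": ["confluence", "web"]}. The enum
# members serialize to plain strings because IndexedSourceTypesResponse sets
# use_enum_values=True; the specific source names here are assumptions.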
@router.get("/connector/{connector_id}")
def get_connector_by_id(
connector_id: int,

View File

@@ -9,6 +9,7 @@ from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
@@ -54,6 +55,11 @@ class ChunkInfo(BaseModel):
num_tokens: int
class IndexedSourceTypesResponse(BaseModel):
model_config = ConfigDict(use_enum_values=True)
source_types: list[DocumentSource]
class DeletionAttemptSnapshot(BaseModel):
connector_id: int
credential_id: int

View File

@@ -0,0 +1,262 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.onyx_web_crawler_client import (
OnyxWebCrawlerClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_content_provider_from_config,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_search_provider_from_config,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
truncate_search_result_content,
)
from onyx.auth.users import current_user
from onyx.chat.models import DOCUMENT_CITATION_NUMBER_EMPTY_VALUE
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.server.features.web_search.models import OpenUrlsToolRequest
from onyx.server.features.web_search.models import OpenUrlsToolResponse
from onyx.server.features.web_search.models import WebSearchToolRequest
from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.features.web_search.models import WebSearchWithContentResponse
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmWebSearchResult,
)
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
router = APIRouter(prefix="/web-search")
logger = setup_logger()
def _get_active_search_provider(
db_session: Session,
) -> tuple[WebSearchProviderView, WebSearchProvider]:
provider_model = fetch_active_web_search_provider(db_session)
if provider_model is None:
raise HTTPException(
status_code=400,
detail="No web search provider configured.",
)
provider_view = WebSearchProviderView(
id=provider_model.id,
name=provider_model.name,
provider_type=WebSearchProviderType(provider_model.provider_type),
is_active=provider_model.is_active,
config=provider_model.config or {},
has_api_key=bool(provider_model.api_key),
)
try:
provider: WebSearchProvider | None = build_search_provider_from_config(
provider_type=provider_view.provider_type,
api_key=provider_model.api_key,
config=provider_model.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400,
detail="Unable to initialize the configured web search provider.",
)
return provider_view, provider
def _get_active_content_provider(
db_session: Session,
) -> tuple[WebContentProviderView | None, WebContentProvider]:
provider_model = fetch_active_web_content_provider(db_session)
if provider_model is None:
# Default to the built-in crawler if nothing is configured. Always available.
# NOTE: the OnyxWebCrawlerClient is not stored in the content provider table,
# so we need to return it directly.
return None, OnyxWebCrawlerClient()
try:
provider_type = WebContentProviderType(provider_model.provider_type)
provider: WebContentProvider | None = build_content_provider_from_config(
provider_type=provider_type,
api_key=provider_model.api_key,
config=provider_model.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400,
detail="Unable to initialize the configured web content provider.",
)
provider_view = WebContentProviderView(
id=provider_model.id,
name=provider_model.name,
provider_type=provider_type,
is_active=provider_model.is_active,
config=provider_model.config or {},
has_api_key=bool(provider_model.api_key),
)
return provider_view, provider
def _run_web_search(
request: WebSearchToolRequest, db_session: Session
) -> tuple[WebSearchProviderType, list[LlmWebSearchResult]]:
provider_view, provider = _get_active_search_provider(db_session)
results: list[LlmWebSearchResult] = []
for query in request.queries:
try:
search_results = provider.search(query)
except HTTPException:
raise
except Exception as exc:
logger.exception("Web search provider failed for query '%s'", query)
raise HTTPException(
status_code=502, detail="Web search provider failed to execute query."
) from exc
trimmed_results = list(search_results)[: request.max_results]
for search_result in trimmed_results:
results.append(
LlmWebSearchResult(
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
url=search_result.link,
title=search_result.title,
snippet=search_result.snippet or "",
unique_identifier_to_strip_away=search_result.link,
)
)
return provider_view.provider_type, results
def _open_urls(
urls: list[str],
db_session: Session,
) -> tuple[WebContentProviderType | None, list[LlmOpenUrlResult]]:
provider_view, provider = _get_active_content_provider(db_session)
try:
docs = provider.contents(urls)
except HTTPException:
raise
except Exception as exc:
logger.exception("Web content provider failed to fetch URLs")
raise HTTPException(
status_code=502, detail="Web content provider failed to fetch URLs."
) from exc
results: list[LlmOpenUrlResult] = []
for doc in docs:
results.append(
LlmOpenUrlResult(
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
content=truncate_search_result_content(doc.full_content),
unique_identifier_to_strip_away=doc.link,
)
)
provider_type = (
provider_view.provider_type
if provider_view
else WebContentProviderType.ONYX_WEB_CRAWLER
)
return provider_type, results
@router.post("/search", response_model=WebSearchWithContentResponse)
def execute_web_search(
request: WebSearchToolRequest,
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> WebSearchWithContentResponse:
"""
Perform a web search and immediately fetch content for the returned URLs.
Use this when you want both snippets and page contents from one call.
If you want to selectively fetch content (i.e. let the LLM decide which URLs to read),
use `/search-lite` and then call `/open-urls` separately.
"""
search_provider_type, search_results = _run_web_search(request, db_session)
if not search_results:
return WebSearchWithContentResponse(
search_provider_type=search_provider_type,
content_provider_type=None,
search_results=[],
full_content_results=[],
)
# Fetch contents for unique URLs in the order they appear
seen: set[str] = set()
urls_to_fetch: list[str] = []
for result in search_results:
url = result.url
if url not in seen:
seen.add(url)
urls_to_fetch.append(url)
content_provider_type, full_content_results = _open_urls(urls_to_fetch, db_session)
return WebSearchWithContentResponse(
search_provider_type=search_provider_type,
content_provider_type=content_provider_type,
search_results=search_results,
full_content_results=full_content_results,
)
@router.post("/search-lite", response_model=WebSearchToolResponse)
def execute_web_search_lite(
request: WebSearchToolRequest,
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> WebSearchToolResponse:
"""
Lightweight search-only endpoint. Returns search snippets and URLs without
fetching page contents. Pair with `/open-urls` if you need to fetch content
later.
"""
provider_type, search_results = _run_web_search(request, db_session)
return WebSearchToolResponse(results=search_results, provider_type=provider_type)
@router.post("/open-urls", response_model=OpenUrlsToolResponse)
def execute_open_urls(
request: OpenUrlsToolRequest,
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> OpenUrlsToolResponse:
"""
Fetch content for specific URLs using the configured content provider.
Intended to complement `/search-lite` when you need content for a subset of URLs.
"""
provider_type, results = _open_urls(request.urls, db_session)
return OpenUrlsToolResponse(results=results, provider_type=provider_type)
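A hedged client-side sketch of the split flow these endpoints describe (/search-lite followed by /open-urls). The host, global prefix, and authentication are assumptions; field names follow the response models used above.

import requests

BASE = "http://localhost:8080/api/web-search"        # host and prefix are assumptions
session = requests.Session()                         # assume auth cookies are already present

lite = session.post(f"{BASE}/search-lite", json={"queries": ["onyx connectors"]}).json()
urls = [result["url"] for result in lite["results"]][:3]   # caller picks a subset to open

full = session.post(f"{BASE}/open-urls", json={"urls": urls}).json()
for item in full["results"]:
    print(len(item["content"]), "characters fetched")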

View File

@@ -0,0 +1,76 @@
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmOpenUrlResult,
)
from onyx.tools.tool_implementations_v2.tool_result_models import (
LlmWebSearchResult,
)
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
class WebSearchToolRequest(BaseModel):
queries: list[str] = Field(
...,
min_length=1,
description="List of search queries to send to the configured provider.",
)
max_results: int | None = Field(
default=10,
description=(
"Optional cap on number of results to return per query. Defaults to 10."
),
)
@field_validator("queries")
@classmethod
def _strip_and_validate_queries(cls, queries: list[str]) -> list[str]:
cleaned_queries = [q.strip() for q in queries if q and q.strip()]
if not cleaned_queries:
raise ValueError("queries must include at least one non-empty value")
return cleaned_queries
@field_validator("max_results")
@classmethod
def _default_and_validate_max_results(cls, max_results: int | None) -> int:
# Default to 10 when not provided
max_results = 10 if max_results is None else max_results
if max_results < 1:
raise ValueError("max_results must be at least 1")
return max_results
class WebSearchToolResponse(BaseModel):
results: list[LlmWebSearchResult]
provider_type: WebSearchProviderType
class WebSearchWithContentResponse(BaseModel):
search_provider_type: WebSearchProviderType
content_provider_type: WebContentProviderType | None = None
search_results: list[LlmWebSearchResult]
full_content_results: list[LlmOpenUrlResult]
class OpenUrlsToolRequest(BaseModel):
urls: list[str] = Field(
...,
min_length=1,
description="URLs to fetch using the configured content provider.",
)
@field_validator("urls")
@classmethod
def _strip_and_validate_urls(cls, urls: list[str]) -> list[str]:
cleaned_urls = [url.strip() for url in urls if url and url.strip()]
if not cleaned_urls:
raise ValueError("urls must include at least one non-empty value")
return cleaned_urls
class OpenUrlsToolResponse(BaseModel):
results: list[LlmOpenUrlResult]
provider_type: WebContentProviderType | None = None
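A small sketch exercising the validators above: blank queries are stripped, empty entries dropped, and max_results falls back to 10 when omitted or None.

from onyx.server.features.web_search.models import WebSearchToolRequest

req = WebSearchToolRequest(queries=["  onyx connectors  ", ""], max_results=None)
print(req.queries)        # ['onyx connectors']
print(req.max_results)    # 10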

View File

@@ -125,18 +125,18 @@ def get_versions() -> AllVersions:
onyx=latest_stable_version,
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.23.4-alpine",
nginx="nginx:1.25.5-alpine",
),
dev=ContainerVersions(
onyx=latest_dev_version,
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.23.4-alpine",
nginx="nginx:1.25.5-alpine",
),
migration=ContainerVersions(
onyx="airgapped-intfloat-nomic-migration",
relational_db="postgres:15.2-alpine",
index="vespaengine/vespa:8.277.17",
nginx="nginx:1.23.4-alpine",
nginx="nginx:1.25.5-alpine",
),
)

View File

@@ -0,0 +1,364 @@
from __future__ import annotations
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_content_provider_from_config,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
build_search_provider_from_config,
)
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.web_search import deactivate_web_content_provider
from onyx.db.web_search import deactivate_web_search_provider
from onyx.db.web_search import delete_web_content_provider
from onyx.db.web_search import delete_web_search_provider
from onyx.db.web_search import fetch_web_content_provider_by_name
from onyx.db.web_search import fetch_web_content_providers
from onyx.db.web_search import fetch_web_search_provider_by_name
from onyx.db.web_search import fetch_web_search_providers
from onyx.db.web_search import set_active_web_content_provider
from onyx.db.web_search import set_active_web_search_provider
from onyx.db.web_search import upsert_web_content_provider
from onyx.db.web_search import upsert_web_search_provider
from onyx.server.manage.web_search.models import WebContentProviderTestRequest
from onyx.server.manage.web_search.models import WebContentProviderUpsertRequest
from onyx.server.manage.web_search.models import WebContentProviderView
from onyx.server.manage.web_search.models import WebSearchProviderTestRequest
from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
from onyx.server.manage.web_search.models import WebSearchProviderView
from onyx.utils.logger import setup_logger
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/web-search")
@admin_router.get("/search-providers", response_model=list[WebSearchProviderView])
def list_search_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[WebSearchProviderView]:
providers = fetch_web_search_providers(db_session)
return [
WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
for provider in providers
]
@admin_router.post("/search-providers", response_model=WebSearchProviderView)
def upsert_search_provider_endpoint(
request: WebSearchProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
existing_by_name = fetch_web_search_provider_by_name(request.name, db_session)
if (
existing_by_name
and request.id is not None
and existing_by_name.id != request.id
):
raise HTTPException(
status_code=400,
detail=f"A search provider named '{request.name}' already exists.",
)
provider = upsert_web_search_provider(
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=request.api_key,
api_key_changed=request.api_key_changed,
config=request.config,
activate=request.activate,
db_session=db_session,
)
db_session.commit()
return WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.delete(
"/search-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
delete_web_search_provider(provider_id, db_session)
return Response(status_code=204)
@admin_router.post("/search-providers/{provider_id}/activate")
def activate_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebSearchProviderView:
provider = set_active_web_search_provider(
provider_id=provider_id, db_session=db_session
)
db_session.commit()
return WebSearchProviderView(
id=provider.id,
name=provider.name,
provider_type=WebSearchProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.post("/search-providers/{provider_id}/deactivate")
def deactivate_search_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
deactivate_web_search_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/search-providers/test")
def test_search_provider(
request: WebSearchProviderTestRequest,
_: User | None = Depends(current_admin_user),
) -> dict[str, str]:
try:
provider = build_search_provider_from_config(
provider_type=request.provider_type,
api_key=request.api_key,
config=request.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400, detail="Unable to build provider configuration."
)
# Actually test the API key by making a real search call
try:
test_results = provider.search("test")
if not test_results or not any(result.link for result in test_results):
raise HTTPException(
status_code=400,
detail="API key validation failed: search returned no results.",
)
except HTTPException:
raise
except Exception as e:
error_msg = str(e)
if (
"api" in error_msg.lower()
or "key" in error_msg.lower()
or "auth" in error_msg.lower()
):
raise HTTPException(
status_code=400,
detail=f"Invalid API key: {error_msg}",
) from e
raise HTTPException(
status_code=400,
detail=f"API key validation failed: {error_msg}",
) from e
logger.info(
f"Web search provider test succeeded for {request.provider_type.value}."
)
return {"status": "ok"}
@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
def list_content_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[WebContentProviderView]:
providers = fetch_web_content_providers(db_session)
return [
WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
for provider in providers
]
@admin_router.post("/content-providers", response_model=WebContentProviderView)
def upsert_content_provider_endpoint(
request: WebContentProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebContentProviderView:
existing_by_name = fetch_web_content_provider_by_name(request.name, db_session)
if (
existing_by_name
and request.id is not None
and existing_by_name.id != request.id
):
raise HTTPException(
status_code=400,
detail=f"A content provider named '{request.name}' already exists.",
)
provider = upsert_web_content_provider(
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=request.api_key,
api_key_changed=request.api_key_changed,
config=request.config,
activate=request.activate,
db_session=db_session,
)
db_session.commit()
return WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.delete(
"/content-providers/{provider_id}", status_code=204, response_class=Response
)
def delete_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
delete_web_content_provider(provider_id, db_session)
return Response(status_code=204)
@admin_router.post("/content-providers/{provider_id}/activate")
def activate_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> WebContentProviderView:
provider = set_active_web_content_provider(
provider_id=provider_id, db_session=db_session
)
db_session.commit()
return WebContentProviderView(
id=provider.id,
name=provider.name,
provider_type=WebContentProviderType(provider.provider_type),
is_active=provider.is_active,
config=provider.config or {},
has_api_key=bool(provider.api_key),
)
@admin_router.post("/content-providers/reset-default")
def reset_content_provider_default(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
providers = fetch_web_content_providers(db_session)
active_ids = [provider.id for provider in providers if provider.is_active]
for provider_id in active_ids:
deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/content-providers/{provider_id}/deactivate")
def deactivate_content_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
deactivate_web_content_provider(provider_id=provider_id, db_session=db_session)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/content-providers/test")
def test_content_provider(
request: WebContentProviderTestRequest,
_: User | None = Depends(current_admin_user),
) -> dict[str, str]:
try:
provider = build_content_provider_from_config(
provider_type=request.provider_type,
api_key=request.api_key,
config=request.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400, detail="Unable to build provider configuration."
)
# Actually test the API key by making a real content fetch call
try:
test_url = "https://example.com"
test_results = provider.contents([test_url])
if not test_results or not any(
result.scrape_successful for result in test_results
):
raise HTTPException(
status_code=400,
detail="API key validation failed: content fetch returned no results.",
)
except HTTPException:
raise
except Exception as e:
error_msg = str(e)
if (
"api" in error_msg.lower()
or "key" in error_msg.lower()
or "auth" in error_msg.lower()
):
raise HTTPException(
status_code=400,
detail=f"Invalid API key: {error_msg}",
) from e
raise HTTPException(
status_code=400,
detail=f"API key validation failed: {error_msg}",
) from e
logger.info(
f"Web content provider test succeeded for {request.provider_type.value}."
)
return {"status": "ok"}

View File

@@ -0,0 +1,69 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from shared_configs.enums import WebContentProviderType
from shared_configs.enums import WebSearchProviderType
class WebSearchProviderView(BaseModel):
id: int
name: str
provider_type: WebSearchProviderType
is_active: bool
config: dict[str, str] | None
has_api_key: bool = Field(
default=False,
description="Indicates whether an API key is stored for this provider.",
)
class WebSearchProviderUpsertRequest(BaseModel):
id: int | None = Field(default=None, description="Existing provider ID to update.")
name: str
provider_type: WebSearchProviderType
config: dict[str, str] | None = None
api_key: str | None = Field(
default=None,
description="API key for the provider. Only required when creating or updating credentials.",
)
api_key_changed: bool = Field(
default=False,
description="Set to true when providing a new API key for an existing provider.",
)
activate: bool = Field(
default=False,
description="If true, sets this provider as the active one after upsert.",
)
class WebContentProviderView(BaseModel):
id: int
name: str
provider_type: WebContentProviderType
is_active: bool
config: dict[str, str] | None
has_api_key: bool = Field(default=False)
class WebContentProviderUpsertRequest(BaseModel):
id: int | None = None
name: str
provider_type: WebContentProviderType
config: dict[str, str] | None = None
api_key: str | None = None
api_key_changed: bool = False
activate: bool = False
class WebSearchProviderTestRequest(BaseModel):
provider_type: WebSearchProviderType
api_key: str | None = None
config: dict[str, Any] | None = None
class WebContentProviderTestRequest(BaseModel):
provider_type: WebContentProviderType
api_key: str | None = None
config: dict[str, Any] | None = None
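A minimal sketch constructing the upsert request above; the EXA enum member and the key are illustrative assumptions. Leaving api_key_changed at its default of False on an update would preserve any previously stored key.

from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
from shared_configs.enums import WebSearchProviderType

payload = WebSearchProviderUpsertRequest(
    name="Exa (primary)",
    provider_type=WebSearchProviderType.EXA,   # assumed enum member
    api_key="exa-...",                         # placeholder
    api_key_changed=True,
    activate=True,
)
print(payload.model_dump(exclude_none=True))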

View File

@@ -16,7 +16,7 @@ class ForceUseTool(BaseModel):
def build_openai_tool_choice_dict(self) -> dict[str, Any]:
"""Build dict in the format that OpenAI expects which tells them to use this tool."""
return {"type": "function", "function": {"name": self.tool_name}}
return {"type": "function", "name": self.tool_name}
def filter_tools_for_force_tool_use(

View File

@@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.message import ToolCallSummary
@@ -51,8 +51,10 @@ class WebSearchTool(Tool[None]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
"""Available only if EXA or SERPER API key is configured."""
return bool(EXA_API_KEY) or bool(SERPER_API_KEY)
"""Available only if an active web search provider is configured in the database."""
with get_session_with_current_tenant() as session:
provider = fetch_active_web_search_provider(session)
return provider is not None
def tool_definition(self) -> dict:
return {

View File

@@ -6,14 +6,20 @@ from pydantic import TypeAdapter
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContentProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
get_default_content_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
WebSearchProvider,
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
@@ -190,7 +196,7 @@ changing or evolving.
def _open_url_core(
run_context: RunContextWrapper[ChatTurnContext],
urls: Sequence[str],
search_provider: WebSearchProvider,
content_provider: WebContentProvider,
) -> list[LlmOpenUrlResult]:
# TODO: Find better way to track index that isn't so implicit
# based on number of tool calls
@@ -206,7 +212,7 @@ def _open_url_core(
)
)
docs = search_provider.contents(urls)
docs = content_provider.contents(urls)
results = [
LlmOpenUrlResult(
document_citation_number=DOCUMENT_CITATION_NUMBER_EMPTY_VALUE,
@@ -267,10 +273,10 @@ def open_url(
"""
Tool for fetching and extracting full content from web pages.
"""
search_provider = get_default_provider()
if search_provider is None:
raise ValueError("No search provider found")
retrieved_docs = _open_url_core(run_context, urls, search_provider)
content_provider = get_default_content_provider()
if content_provider is None:
raise ValueError("No web content provider found")
retrieved_docs = _open_url_core(run_context, urls, content_provider)
adapter = TypeAdapter(list[LlmOpenUrlResult])
return adapter.dump_json(retrieved_docs).decode()

View File

@@ -80,7 +80,11 @@ def setup_braintrust_if_creds_available() -> None:
api_key=BRAINTRUST_API_KEY,
)
braintrust.set_masking_function(_mask)
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
_setup_legacy_langchain_tracing()
logger.notice("Braintrust tracing initialized")
def _setup_legacy_langchain_tracing() -> None:
handler = BraintrustCallbackHandler()
set_global_handler(handler)
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
logger.notice("Braintrust tracing initialized")

View File

@@ -0,0 +1,216 @@
import datetime
from typing import Any
from typing import Dict
from typing import Optional
import braintrust
from braintrust import NOOP_SPAN
from .framework.processor_interface import TracingProcessor
from .framework.span_data import AgentSpanData
from .framework.span_data import FunctionSpanData
from .framework.span_data import GenerationSpanData
from .framework.span_data import SpanData
from .framework.spans import Span
from .framework.traces import Trace
def _span_type(span: Span[Any]) -> braintrust.SpanTypeAttribute:
if span.span_data.type in ["agent"]:
return braintrust.SpanTypeAttribute.TASK
elif span.span_data.type in ["function"]:
return braintrust.SpanTypeAttribute.TOOL
elif span.span_data.type in ["generation"]:
return braintrust.SpanTypeAttribute.LLM
else:
return braintrust.SpanTypeAttribute.TASK
def _span_name(span: Span[Any]) -> str:
if isinstance(span.span_data, AgentSpanData) or isinstance(
span.span_data, FunctionSpanData
):
return span.span_data.name
elif isinstance(span.span_data, GenerationSpanData):
return "Generation"
else:
return "Unknown"
def _timestamp_from_maybe_iso(timestamp: Optional[str]) -> Optional[float]:
if timestamp is None:
return None
return datetime.datetime.fromisoformat(timestamp).timestamp()
def _maybe_timestamp_elapsed(
end: Optional[str], start: Optional[str]
) -> Optional[float]:
if start is None or end is None:
return None
return (
datetime.datetime.fromisoformat(end) - datetime.datetime.fromisoformat(start)
).total_seconds()
class BraintrustTracingProcessor(TracingProcessor):
"""
`BraintrustTracingProcessor` is a `tracing.TracingProcessor` that logs traces to Braintrust.
Args:
logger: A `braintrust.Span` or `braintrust.Experiment` or `braintrust.Logger` to use for logging.
If `None`, the current span, experiment, or logger will be selected exactly as in `braintrust.start_span`.
"""
def __init__(self, logger: Optional[braintrust.Logger] = None):
self._logger = logger
self._spans: Dict[str, Any] = {}
self._first_input: Dict[str, Any] = {}
self._last_output: Dict[str, Any] = {}
def on_trace_start(self, trace: Trace) -> None:
trace_meta = trace.export() or {}
metadata = {
"group_id": trace_meta.get("group_id"),
**(trace_meta.get("metadata") or {}),
}
current_context = braintrust.current_span()
if current_context != NOOP_SPAN:
self._spans[trace.trace_id] = current_context.start_span( # type: ignore[assignment]
name=trace.name,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,
)
elif self._logger is not None:
self._spans[trace.trace_id] = self._logger.start_span( # type: ignore[assignment]
span_attributes={"type": "task", "name": trace.name},
span_id=trace.trace_id,
root_span_id=trace.trace_id,
metadata=metadata,
)
else:
self._spans[trace.trace_id] = braintrust.start_span( # type: ignore[assignment]
id=trace.trace_id,
span_attributes={"type": "task", "name": trace.name},
metadata=metadata,
)
def on_trace_end(self, trace: Trace) -> None:
span: Any = self._spans.pop(trace.trace_id)
# Get the first input and last output for this specific trace
trace_first_input = self._first_input.pop(trace.trace_id, None)
trace_last_output = self._last_output.pop(trace.trace_id, None)
span.log(input=trace_first_input, output=trace_last_output)
span.end()
def _agent_log_data(self, span: Span[AgentSpanData]) -> Dict[str, Any]:
return {
"metadata": {
"tools": span.span_data.tools,
"handoffs": span.span_data.handoffs,
"output_type": span.span_data.output_type,
}
}
def _function_log_data(self, span: Span[FunctionSpanData]) -> Dict[str, Any]:
return {
"input": span.span_data.input,
"output": span.span_data.output,
}
def _generation_log_data(self, span: Span[GenerationSpanData]) -> Dict[str, Any]:
metrics = {}
ttft = _maybe_timestamp_elapsed(span.ended_at, span.started_at)
if ttft is not None:
metrics["time_to_first_token"] = ttft
usage = span.span_data.usage or {}
if "prompt_tokens" in usage:
metrics["prompt_tokens"] = usage["prompt_tokens"]
elif "input_tokens" in usage:
metrics["prompt_tokens"] = usage["input_tokens"]
if "completion_tokens" in usage:
metrics["completion_tokens"] = usage["completion_tokens"]
elif "output_tokens" in usage:
metrics["completion_tokens"] = usage["output_tokens"]
if "total_tokens" in usage:
metrics["tokens"] = usage["total_tokens"]
elif "input_tokens" in usage and "output_tokens" in usage:
metrics["tokens"] = usage["input_tokens"] + usage["output_tokens"]
if "cache_read_input_tokens" in usage:
metrics["prompt_cached_tokens"] = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
metrics["prompt_cache_creation_tokens"] = usage[
"cache_creation_input_tokens"
]
return {
"input": span.span_data.input,
"output": span.span_data.output,
"metadata": {
"model": span.span_data.model,
"model_config": span.span_data.model_config,
},
"metrics": metrics,
}
def _log_data(self, span: Span[Any]) -> Dict[str, Any]:
if isinstance(span.span_data, AgentSpanData):
return self._agent_log_data(span)
elif isinstance(span.span_data, FunctionSpanData):
return self._function_log_data(span)
elif isinstance(span.span_data, GenerationSpanData):
return self._generation_log_data(span)
else:
return {}
def on_span_start(self, span: Span[SpanData]) -> None:
parent: Any = (
self._spans[span.parent_id]
if span.parent_id is not None
else self._spans[span.trace_id]
)
created_span: Any = parent.start_span(
id=span.span_id,
name=_span_name(span),
type=_span_type(span),
start_time=_timestamp_from_maybe_iso(span.started_at),
)
self._spans[span.span_id] = created_span
# Set the span as current so current_span() calls will return it
created_span.set_current()
def on_span_end(self, span: Span[SpanData]) -> None:
s: Any = self._spans.pop(span.span_id)
event = dict(error=span.error, **self._log_data(span))
s.log(**event)
s.unset_current()
s.end(_timestamp_from_maybe_iso(span.ended_at))
input_ = event.get("input")
output = event.get("output")
# Store first input and last output per trace_id
trace_id = span.trace_id
if trace_id not in self._first_input and input_ is not None:
self._first_input[trace_id] = input_
if output is not None:
self._last_output[trace_id] = output
def shutdown(self) -> None:
if self._logger is not None:
self._logger.flush()
else:
braintrust.flush()
def force_flush(self) -> None:
if self._logger is not None:
self._logger.flush()
else:
braintrust.flush()
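As the setup diff earlier in this compare shows, wiring this processor into the framework is a single `set_trace_processors` call. A minimal sketch, assuming Braintrust credentials are already configured and that `braintrust.init_logger` is used with an illustrative project name:

```python
# Hedged sketch: mirror framework traces/spans into Braintrust.
# The project name is illustrative; credentials are assumed to be configured.
import braintrust

braintrust_logger = braintrust.init_logger(project="onyx")
set_trace_processors([BraintrustTracingProcessor(braintrust_logger)])
# From here on, every trace()/span created via the framework is forwarded to Braintrust.
```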

View File

@@ -0,0 +1,21 @@
from .processor_interface import TracingProcessor
from .provider import DefaultTraceProvider
from .setup import get_trace_provider
from .setup import set_trace_provider
def add_trace_processor(span_processor: TracingProcessor) -> None:
"""
Adds a new trace processor. This processor will receive all traces/spans.
"""
get_trace_provider().register_processor(span_processor)
def set_trace_processors(processors: list[TracingProcessor]) -> None:
"""
Set the list of trace processors. This will replace the current list of processors.
"""
get_trace_provider().set_processors(processors)
set_trace_provider(DefaultTraceProvider())
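The two entry points differ only in append-versus-replace semantics: `add_trace_processor` keeps any processors already registered, while `set_trace_processors` swaps out the whole list. A small illustrative sketch (the console processor is a hypothetical custom `TracingProcessor`):

```python
console_processor = ConsoleTracingProcessor()      # hypothetical custom processor
braintrust_processor = BraintrustTracingProcessor(None)

add_trace_processor(console_processor)             # registered alongside existing processors
set_trace_processors([braintrust_processor])       # now the only registered processor
```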

View File

@@ -0,0 +1,21 @@
from typing import Any
from .create import get_current_span
from .spans import Span
from .spans import SpanError
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
span.set_error(error)
def attach_error_to_current_span(error: SpanError) -> None:
span = get_current_span()
if span:
attach_error_to_span(span, error)
else:
logger.warning(f"No span to add error {error} to")

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING
from .setup import get_trace_provider
from .span_data import AgentSpanData
from .span_data import FunctionSpanData
from .span_data import GenerationSpanData
from .spans import Span
from .traces import Trace
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
pass
logger = setup_logger(__name__)
def trace(
workflow_name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""
Create a new trace. The trace will not be started automatically; you should either use
it as a context manager (`with trace(...):`) or call `trace.start()` + `trace.finish()`
manually.
In addition to the workflow name and optional grouping identifier, you can provide
an arbitrary metadata dictionary to attach additional user-defined information to
the trace.
Args:
workflow_name: The name of the logical app or workflow. For example, you might provide
"code_bot" for a coding agent, or "customer_support_agent" for a customer support agent.
trace_id: The ID of the trace. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_trace_id()` to generate a trace ID, to guarantee that IDs are
correctly formatted.
group_id: Optional grouping identifier to link multiple traces from the same conversation
or process. For instance, you might use a chat thread ID.
metadata: Optional dictionary of additional metadata to attach to the trace.
disabled: If True, we will return a Trace but the Trace will not be recorded.
Returns:
The newly created trace object.
"""
current_trace = get_trace_provider().get_current_trace()
if current_trace:
logger.warning(
"Trace already exists. Creating a new trace, but this is probably a mistake."
)
return get_trace_provider().create_trace(
name=workflow_name,
trace_id=trace_id,
group_id=group_id,
metadata=metadata,
disabled=disabled,
)
def get_current_trace() -> Trace | None:
"""Returns the currently active trace, if present."""
return get_trace_provider().get_current_trace()
def get_current_span() -> Span[Any] | None:
"""Returns the currently active span, if present."""
return get_trace_provider().get_current_span()
def agent_span(
name: str,
handoffs: list[str] | None = None,
tools: list[str] | None = None,
output_type: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[AgentSpanData]:
"""Create a new agent span. The span will not be started automatically, you should either do
`with agent_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
name: The name of the agent.
handoffs: Optional list of agent names to which this agent could hand off control.
tools: Optional list of tool names available to this agent.
output_type: Optional name of the output type produced by the agent.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created agent span.
"""
return get_trace_provider().create_span(
span_data=AgentSpanData(
name=name, handoffs=handoffs, tools=tools, output_type=output_type
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def function_span(
name: str,
input: str | None = None,
output: str | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[FunctionSpanData]:
"""Create a new function span. The span will not be started automatically, you should either do
`with function_span() ...` or call `span.start()` + `span.finish()` manually.
Args:
name: The name of the function.
input: The input to the function.
output: The output of the function.
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created function span.
"""
return get_trace_provider().create_span(
span_data=FunctionSpanData(name=name, input=input, output=output),
span_id=span_id,
parent=parent,
disabled=disabled,
)
def generation_span(
input: Sequence[Mapping[str, Any]] | None = None,
output: Sequence[Mapping[str, Any]] | None = None,
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
usage: dict[str, Any] | None = None,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[GenerationSpanData]:
"""Create a new generation span. The span will not be started automatically, you should either
do `with generation_span() ...` or call `span.start()` + `span.finish()` manually.
This span captures the details of a model generation, including the
input message sequence, any generated outputs, the model name and
configuration, and usage data. If you only need to capture a model
response identifier, use `response_span()` instead.
Args:
input: The sequence of input messages sent to the model.
output: The sequence of output messages received from the model.
model: The model identifier used for the generation.
model_config: The model configuration (hyperparameters) used.
usage: A dictionary of usage information (input tokens, output tokens, etc.).
span_id: The ID of the span. Optional. If not provided, we will generate an ID. We
recommend using `util.gen_span_id()` to generate a span ID, to guarantee that IDs are
correctly formatted.
parent: The parent span or trace. If not provided, we will automatically use the current
trace/span as the parent.
disabled: If True, we will return a Span but the Span will not be recorded.
Returns:
The newly created generation span.
"""
return get_trace_provider().create_span(
span_data=GenerationSpanData(
input=input,
output=output,
model=model,
model_config=model_config,
usage=usage,
),
span_id=span_id,
parent=parent,
disabled=disabled,
)
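Putting the constructors together: a typical instrumentation block opens a trace, nests an agent span, and records one LLM call in a generation span. A minimal sketch, assuming the default trace provider has been installed (as done at import time in the processors module) and with purely illustrative model, message, and usage values:

```python
# Illustrative values only; the constructors are the ones defined above.
with trace("example_workflow", metadata={"tenant": "demo"}):
    with agent_span(name="research_agent", tools=["web_search"]):
        with generation_span(
            input=[{"role": "user", "content": "What is tracing?"}],
            model="gpt-4o",
            usage={"input_tokens": 12, "output_tokens": 40},
        ) as gen:
            # ... call the model here ...
            gen.span_data.output = [{"role": "assistant", "content": "..."}]
```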

View File

@@ -0,0 +1,136 @@
import abc
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .spans import Span
from .traces import Trace
class TracingProcessor(abc.ABC):
"""Interface for processing and monitoring traces and spans in the OpenAI Agents system.
This abstract class defines the interface that all tracing processors must implement.
Processors receive notifications when traces and spans start and end, allowing them
to collect, process, and export tracing data.
Example:
```python
class CustomProcessor(TracingProcessor):
def __init__(self):
self.active_traces = {}
self.active_spans = {}
def on_trace_start(self, trace):
self.active_traces[trace.trace_id] = trace
def on_trace_end(self, trace):
# Process completed trace
del self.active_traces[trace.trace_id]
def on_span_start(self, span):
self.active_spans[span.span_id] = span
def on_span_end(self, span):
# Process completed span
del self.active_spans[span.span_id]
def shutdown(self):
# Clean up resources
self.active_traces.clear()
self.active_spans.clear()
def force_flush(self):
# Force processing of any queued items
pass
```
Notes:
- All methods should be thread-safe
- Methods should not block for long periods
- Handle errors gracefully to prevent disrupting agent execution
"""
@abc.abstractmethod
def on_trace_start(self, trace: "Trace") -> None:
"""Called when a new trace begins execution.
Args:
trace: The trace that started. Contains workflow name and metadata.
Notes:
- Called synchronously on trace start
- Should return quickly to avoid blocking execution
- Any errors should be caught and handled internally
"""
@abc.abstractmethod
def on_trace_end(self, trace: "Trace") -> None:
"""Called when a trace completes execution.
Args:
trace: The completed trace containing all spans and results.
Notes:
- Called synchronously when trace finishes
- Good time to export/process the complete trace
- Should handle cleanup of any trace-specific resources
"""
@abc.abstractmethod
def on_span_start(self, span: "Span[Any]") -> None:
"""Called when a new span begins execution.
Args:
span: The span that started. Contains operation details and context.
Notes:
- Called synchronously on span start
- Should return quickly to avoid blocking execution
- Spans are automatically nested under current trace/span
"""
@abc.abstractmethod
def on_span_end(self, span: "Span[Any]") -> None:
"""Called when a span completes execution.
Args:
span: The completed span containing execution results.
Notes:
- Called synchronously when span finishes
- Should not block or raise exceptions
- Good time to export/process the individual span
"""
@abc.abstractmethod
def shutdown(self) -> None:
"""Called when the application stops to clean up resources.
Should perform any necessary cleanup like:
- Flushing queued traces/spans
- Closing connections
- Releasing resources
"""
@abc.abstractmethod
def force_flush(self) -> None:
"""Forces immediate processing of any queued traces/spans.
Notes:
- Should process all queued items before returning
- Useful before shutdown or when immediate processing is needed
- May block while processing completes
"""
class TracingExporter(abc.ABC):
"""Exports traces and spans. For example, could log them or send them to a backend."""
@abc.abstractmethod
def export(self, items: list["Trace | Span[Any]"]) -> None:
"""Exports a list of traces and spans.
Args:
items: The items to export.
"""

View File

@@ -0,0 +1,319 @@
from __future__ import annotations
import threading
import uuid
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timezone
from typing import Any
from .processor_interface import TracingProcessor
from .scope import Scope
from .spans import NoOpSpan
from .spans import Span
from .spans import SpanImpl
from .spans import TSpanData
from .traces import NoOpTrace
from .traces import Trace
from .traces import TraceImpl
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
class SynchronousMultiTracingProcessor(TracingProcessor):
"""
Forwards all calls to a list of TracingProcessors, in order of registration.
"""
def __init__(self) -> None:
# Using a tuple to avoid race conditions when iterating over processors
self._processors: tuple[TracingProcessor, ...] = ()
self._lock = threading.Lock()
def add_tracing_processor(self, tracing_processor: TracingProcessor) -> None:
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
with self._lock:
self._processors += (tracing_processor,)
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""
Set the list of processors. This will replace the current list of processors.
"""
with self._lock:
self._processors = tuple(processors)
def on_trace_start(self, trace: Trace) -> None:
"""
Called when a trace is started.
"""
for processor in self._processors:
try:
processor.on_trace_start(trace)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_trace_start: {e}"
)
def on_trace_end(self, trace: Trace) -> None:
"""
Called when a trace is finished.
"""
for processor in self._processors:
try:
processor.on_trace_end(trace)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_trace_end: {e}"
)
def on_span_start(self, span: Span[Any]) -> None:
"""
Called when a span is started.
"""
for processor in self._processors:
try:
processor.on_span_start(span)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_span_start: {e}"
)
def on_span_end(self, span: Span[Any]) -> None:
"""
Called when a span is finished.
"""
for processor in self._processors:
try:
processor.on_span_end(span)
except Exception as e:
logger.error(
f"Error in trace processor {processor} during on_span_end: {e}"
)
def shutdown(self) -> None:
"""
Called when the application stops.
"""
for processor in self._processors:
logger.debug(f"Shutting down trace processor {processor}")
try:
processor.shutdown()
except Exception as e:
logger.error(f"Error shutting down trace processor {processor}: {e}")
def force_flush(self) -> None:
"""
Force the processors to flush their buffers.
"""
for processor in self._processors:
try:
processor.force_flush()
except Exception as e:
logger.error(f"Error flushing trace processor {processor}: {e}")
class TraceProvider(ABC):
"""Interface for creating traces and spans."""
@abstractmethod
def register_processor(self, processor: TracingProcessor) -> None:
"""Add a processor that will receive all traces and spans."""
@abstractmethod
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""Replace the list of processors with ``processors``."""
@abstractmethod
def get_current_trace(self) -> Trace | None:
"""Return the currently active trace, if any."""
@abstractmethod
def get_current_span(self) -> Span[Any] | None:
"""Return the currently active span, if any."""
@abstractmethod
def time_iso(self) -> str:
"""Return the current time in ISO 8601 format."""
@abstractmethod
def gen_trace_id(self) -> str:
"""Generate a new trace identifier."""
@abstractmethod
def gen_span_id(self) -> str:
"""Generate a new span identifier."""
@abstractmethod
def gen_group_id(self) -> str:
"""Generate a new group identifier."""
@abstractmethod
def create_trace(
self,
name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""Create a new trace."""
@abstractmethod
def create_span(
self,
span_data: TSpanData,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TSpanData]:
"""Create a new span."""
@abstractmethod
def shutdown(self) -> None:
"""Clean up any resources used by the provider."""
class DefaultTraceProvider(TraceProvider):
def __init__(self) -> None:
self._multi_processor = SynchronousMultiTracingProcessor()
def register_processor(self, processor: TracingProcessor) -> None:
"""
Add a processor to the list of processors. Each processor will receive all traces/spans.
"""
self._multi_processor.add_tracing_processor(processor)
def set_processors(self, processors: list[TracingProcessor]) -> None:
"""
Set the list of processors. This will replace the current list of processors.
"""
self._multi_processor.set_processors(processors)
def get_current_trace(self) -> Trace | None:
"""
Returns the currently active trace, if any.
"""
return Scope.get_current_trace()
def get_current_span(self) -> Span[Any] | None:
"""
Returns the currently active span, if any.
"""
return Scope.get_current_span()
def time_iso(self) -> str:
"""Return the current time in ISO 8601 format."""
return datetime.now(timezone.utc).isoformat()
def gen_trace_id(self) -> str:
"""Generate a new trace ID."""
return f"trace_{uuid.uuid4().hex}"
def gen_span_id(self) -> str:
"""Generate a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}"
def gen_group_id(self) -> str:
"""Generate a new group ID."""
return f"group_{uuid.uuid4().hex[:24]}"
def create_trace(
self,
name: str,
trace_id: str | None = None,
group_id: str | None = None,
metadata: dict[str, Any] | None = None,
disabled: bool = False,
) -> Trace:
"""
Create a new trace.
"""
if disabled:
logger.debug(f"Tracing is disabled. Not creating trace {name}")
return NoOpTrace()
trace_id = trace_id or self.gen_trace_id()
logger.debug(f"Creating trace {name} with id {trace_id}")
return TraceImpl(
name=name,
trace_id=trace_id,
group_id=group_id,
metadata=metadata,
processor=self._multi_processor,
)
def create_span(
self,
span_data: TSpanData,
span_id: str | None = None,
parent: Trace | Span[Any] | None = None,
disabled: bool = False,
) -> Span[TSpanData]:
"""
Create a new span.
"""
if disabled:
logger.debug(f"Tracing is disabled. Not creating span {span_data}")
return NoOpSpan(span_data)
trace_id: str
parent_id: str | None
if not parent:
current_span = Scope.get_current_span()
current_trace = Scope.get_current_trace()
if current_trace is None:
logger.error(
"No active trace. Make sure to start a trace with `trace()` first "
"Returning NoOpSpan."
)
return NoOpSpan(span_data)
elif isinstance(current_trace, NoOpTrace) or isinstance(
current_span, NoOpSpan
):
logger.debug(
f"Parent {current_span} or {current_trace} is no-op, returning NoOpSpan"
)
return NoOpSpan(span_data)
parent_id = current_span.span_id if current_span else None
trace_id = current_trace.trace_id
elif isinstance(parent, Trace):
if isinstance(parent, NoOpTrace):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
trace_id = parent.trace_id
parent_id = None
elif isinstance(parent, Span):
if isinstance(parent, NoOpSpan):
logger.debug(f"Parent {parent} is no-op, returning NoOpSpan")
return NoOpSpan(span_data)
parent_id = parent.span_id
trace_id = parent.trace_id
else:
# This should never happen, but mypy needs it
raise ValueError(f"Invalid parent type: {type(parent)}")
logger.debug(f"Creating span {span_data} with id {span_id}")
return SpanImpl(
trace_id=trace_id,
span_id=span_id or self.gen_span_id(),
parent_id=parent_id,
processor=self._multi_processor,
span_data=span_data,
)
def shutdown(self) -> None:
try:
logger.debug("Shutting down trace provider")
self._multi_processor.shutdown()
except Exception as e:
logger.error(f"Error shutting down trace provider: {e}")

View File

@@ -0,0 +1,49 @@
import contextvars
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .spans import Span
from .traces import Trace
_current_span: contextvars.ContextVar["Span[Any] | None"] = contextvars.ContextVar(
"current_span", default=None
)
_current_trace: contextvars.ContextVar["Trace | None"] = contextvars.ContextVar(
"current_trace", default=None
)
class Scope:
"""
Manages the current span and trace in the context.
"""
@classmethod
def get_current_span(cls) -> "Span[Any] | None":
return _current_span.get()
@classmethod
def set_current_span(
cls, span: "Span[Any] | None"
) -> "contextvars.Token[Span[Any] | None]":
return _current_span.set(span)
@classmethod
def reset_current_span(cls, token: "contextvars.Token[Span[Any] | None]") -> None:
_current_span.reset(token)
@classmethod
def get_current_trace(cls) -> "Trace | None":
return _current_trace.get()
@classmethod
def set_current_trace(
cls, trace: "Trace | None"
) -> "contextvars.Token[Trace | None]":
return _current_trace.set(trace)
@classmethod
def reset_current_trace(cls, token: "contextvars.Token[Trace | None]") -> None:
_current_trace.reset(token)
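`Scope` is a thin wrapper around two `contextvars.ContextVar`s, so the "current" trace and span are tracked per execution context (thread or asyncio task) and restored via the token returned by the setters. A minimal sketch of the set/reset discipline that `TraceImpl` and `SpanImpl` rely on, with `my_trace` standing in for any `Trace` instance:

```python
# Sketch of the token-based set/reset discipline.
previous = Scope.get_current_trace()         # may be None
token = Scope.set_current_trace(my_trace)    # my_trace: any Trace instance (assumed)
try:
    assert Scope.get_current_trace() is my_trace
finally:
    Scope.reset_current_trace(token)         # restores whatever was current before
    assert Scope.get_current_trace() is previous
```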

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .provider import TraceProvider
GLOBAL_TRACE_PROVIDER: TraceProvider | None = None
def set_trace_provider(provider: TraceProvider) -> None:
"""Set the global trace provider used by tracing utilities."""
global GLOBAL_TRACE_PROVIDER
GLOBAL_TRACE_PROVIDER = provider
def get_trace_provider() -> TraceProvider:
"""Get the global trace provider used by tracing utilities."""
if GLOBAL_TRACE_PROVIDER is None:
raise RuntimeError("Trace provider not set")
return GLOBAL_TRACE_PROVIDER

View File

@@ -0,0 +1,130 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
class SpanData(abc.ABC):
"""
Represents span data in the trace.
"""
@abc.abstractmethod
def export(self) -> dict[str, Any]:
"""Export the span data as a dictionary."""
@property
@abc.abstractmethod
def type(self) -> str:
"""Return the type of the span."""
class AgentSpanData(SpanData):
"""
Represents an Agent Span in the trace.
Includes name, handoffs, tools, and output type.
"""
__slots__ = ("name", "handoffs", "tools", "output_type")
def __init__(
self,
name: str,
handoffs: list[str] | None = None,
tools: list[str] | None = None,
output_type: str | None = None,
):
self.name = name
self.handoffs: list[str] | None = handoffs
self.tools: list[str] | None = tools
self.output_type: str | None = output_type
@property
def type(self) -> str:
return "agent"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"name": self.name,
"handoffs": self.handoffs,
"tools": self.tools,
"output_type": self.output_type,
}
class FunctionSpanData(SpanData):
"""
Represents a Function Span in the trace.
Includes input, output and MCP data (if applicable).
"""
__slots__ = ("name", "input", "output", "mcp_data")
def __init__(
self,
name: str,
input: str | None,
output: Any | None,
mcp_data: dict[str, Any] | None = None,
):
self.name = name
self.input = input
self.output = output
self.mcp_data = mcp_data
@property
def type(self) -> str:
return "function"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"name": self.name,
"input": self.input,
"output": str(self.output) if self.output else None,
"mcp_data": self.mcp_data,
}
class GenerationSpanData(SpanData):
"""
Represents a Generation Span in the trace.
Includes input, output, model, model configuration, and usage.
"""
__slots__ = (
"input",
"output",
"model",
"model_config",
"usage",
)
def __init__(
self,
input: Sequence[Mapping[str, Any]] | None = None,
output: Sequence[Mapping[str, Any]] | None = None,
model: str | None = None,
model_config: Mapping[str, Any] | None = None,
usage: dict[str, Any] | None = None,
):
self.input = input
self.output = output
self.model = model
self.model_config = model_config
self.usage = usage
@property
def type(self) -> str:
return "generation"
def export(self) -> dict[str, Any]:
return {
"type": self.type,
"input": self.input,
"output": self.output,
"model": self.model,
"model_config": self.model_config,
"usage": self.usage,
}
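Each concrete `SpanData` subclass is a plain slotted container whose `export()` returns a JSON-friendly dict; processors decide what to do with it. For example (values illustrative, and note that a non-None function output is stringified):

```python
data = FunctionSpanData(name="run_search", input='{"query": "tracing"}', output={"hits": 3})
data.export()
# {
#     "type": "function",
#     "name": "run_search",
#     "input": '{"query": "tracing"}',
#     "output": "{'hits': 3}",   # non-None output is passed through str()
#     "mcp_data": None,
# }
```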

View File

@@ -0,0 +1,356 @@
from __future__ import annotations
import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from typing_extensions import TypedDict
from . import util
from .processor_interface import TracingProcessor
from .scope import Scope
from .span_data import SpanData
TSpanData = TypeVar("TSpanData", bound=SpanData)
class SpanError(TypedDict):
"""Represents an error that occurred during span execution.
Attributes:
message: A human-readable error description
data: Optional dictionary containing additional error context
"""
message: str
data: dict[str, Any] | None
class Span(abc.ABC, Generic[TSpanData]):
"""Base class for representing traceable operations with timing and context.
A span represents a single operation within a trace (e.g., an LLM call, tool execution,
or agent run). Spans track timing, relationships between operations, and operation-specific
data.
Type Args:
TSpanData: The type of span-specific data this span contains.
Example:
```python
# Creating a custom span
with custom_span("database_query", {
"operation": "SELECT",
"table": "users"
}) as span:
results = await db.query("SELECT * FROM users")
span.set_output({"count": len(results)})
# Handling errors in spans
with custom_span("risky_operation") as span:
try:
result = perform_risky_operation()
except Exception as e:
span.set_error({
"message": str(e),
"data": {"operation": "risky_operation"}
})
raise
```
Notes:
- Spans automatically nest under the current trace
- Use context managers for reliable start/finish
- Include relevant data but avoid sensitive information
- Handle errors properly using set_error()
"""
@property
@abc.abstractmethod
def trace_id(self) -> str:
"""The ID of the trace this span belongs to.
Returns:
str: Unique identifier of the parent trace.
"""
@property
@abc.abstractmethod
def span_id(self) -> str:
"""Unique identifier for this span.
Returns:
str: The span's unique ID within its trace.
"""
@property
@abc.abstractmethod
def span_data(self) -> TSpanData:
"""Operation-specific data for this span.
Returns:
TSpanData: Data specific to this type of span (e.g., LLM generation data).
"""
@abc.abstractmethod
def start(self, mark_as_current: bool = False) -> None:
"""
Start the span.
Args:
mark_as_current: If true, the span will be marked as the current span.
"""
@abc.abstractmethod
def finish(self, reset_current: bool = False) -> None:
"""
Finish the span.
Args:
reset_current: If true, the current span will be reset to the previous span.
"""
@abc.abstractmethod
def __enter__(self) -> Span[TSpanData]:
pass
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
@property
@abc.abstractmethod
def parent_id(self) -> str | None:
"""ID of the parent span, if any.
Returns:
str | None: The parent span's ID, or None if this is a root span.
"""
@abc.abstractmethod
def set_error(self, error: SpanError) -> None:
pass
@property
@abc.abstractmethod
def error(self) -> SpanError | None:
"""Any error that occurred during span execution.
Returns:
SpanError | None: Error details if an error occurred, None otherwise.
"""
@abc.abstractmethod
def export(self) -> dict[str, Any] | None:
pass
@property
@abc.abstractmethod
def started_at(self) -> str | None:
"""When the span started execution.
Returns:
str | None: ISO format timestamp of span start, None if not started.
"""
@property
@abc.abstractmethod
def ended_at(self) -> str | None:
"""When the span finished execution.
Returns:
str | None: ISO format timestamp of span end, None if not finished.
"""
class NoOpSpan(Span[TSpanData]):
"""A no-op implementation of Span that doesn't record any data.
Used when tracing is disabled but span operations still need to work.
Args:
span_data: The operation-specific data for this span.
"""
__slots__ = ("_span_data", "_prev_span_token")
def __init__(self, span_data: TSpanData):
self._span_data = span_data
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
@property
def trace_id(self) -> str:
return "no-op"
@property
def span_id(self) -> str:
return "no-op"
@property
def span_data(self) -> TSpanData:
return self._span_data
@property
def parent_id(self) -> str | None:
return None
def start(self, mark_as_current: bool = False) -> None:
if mark_as_current:
self._prev_span_token = Scope.set_current_span(self)
def finish(self, reset_current: bool = False) -> None:
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
self._prev_span_token = None
def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
reset_current = True
if exc_type is GeneratorExit:
reset_current = False
self.finish(reset_current=reset_current)
def set_error(self, error: SpanError) -> None:
pass
@property
def error(self) -> SpanError | None:
return None
def export(self) -> dict[str, Any] | None:
return None
@property
def started_at(self) -> str | None:
return None
@property
def ended_at(self) -> str | None:
return None
class SpanImpl(Span[TSpanData]):
__slots__ = (
"_trace_id",
"_span_id",
"_parent_id",
"_started_at",
"_ended_at",
"_error",
"_prev_span_token",
"_processor",
"_span_data",
)
def __init__(
self,
trace_id: str,
span_id: str | None,
parent_id: str | None,
processor: TracingProcessor,
span_data: TSpanData,
):
self._trace_id = trace_id
self._span_id = span_id or util.gen_span_id()
self._parent_id = parent_id
self._started_at: str | None = None
self._ended_at: str | None = None
self._processor = processor
self._error: SpanError | None = None
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
self._span_data = span_data
@property
def trace_id(self) -> str:
return self._trace_id
@property
def span_id(self) -> str:
return self._span_id
@property
def span_data(self) -> TSpanData:
return self._span_data
@property
def parent_id(self) -> str | None:
return self._parent_id
def start(self, mark_as_current: bool = False) -> None:
if self.started_at is not None:
return
self._started_at = util.time_iso()
self._processor.on_span_start(self)
if mark_as_current:
self._prev_span_token = Scope.set_current_span(self)
def finish(self, reset_current: bool = False) -> None:
if self.ended_at is not None:
return
self._ended_at = util.time_iso()
self._processor.on_span_end(self)
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
self._prev_span_token = None
def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
reset_current = True
if exc_type is GeneratorExit:
reset_current = False
self.finish(reset_current=reset_current)
def set_error(self, error: SpanError) -> None:
self._error = error
@property
def error(self) -> SpanError | None:
return self._error
@property
def started_at(self) -> str | None:
return self._started_at
@property
def ended_at(self) -> str | None:
return self._ended_at
def export(self) -> dict[str, Any] | None:
return {
"object": "trace.span",
"id": self.span_id,
"trace_id": self.trace_id,
"parent_id": self._parent_id,
"started_at": self._started_at,
"ended_at": self._ended_at,
"span_data": self.span_data.export(),
"error": self._error,
}
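`SpanImpl` fires `on_span_start`/`on_span_end` at most once each and uses the context-manager protocol to keep the `Scope` stack consistent, deliberately skipping the reset when a generator is closed (`GeneratorExit`). A minimal sketch of the two equivalent usage styles, with `my_processor` standing in for any `TracingProcessor`:

```python
# Context-manager style: start/finish and Scope bookkeeping are automatic.
with SpanImpl(
    trace_id="trace_abc",
    span_id=None,                      # auto-generated via util.gen_span_id()
    parent_id=None,
    processor=my_processor,            # any TracingProcessor (assumed)
    span_data=FunctionSpanData(name="step", input=None, output=None),
) as span:
    span.set_error(SpanError(message="oops", data=None))  # optional

# Manual style: equivalent, but the caller owns the ordering.
span = SpanImpl("trace_abc", None, None, my_processor, FunctionSpanData("step", None, None))
span.start(mark_as_current=True)
span.finish(reset_current=True)
```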

View File

@@ -0,0 +1,287 @@
from __future__ import annotations
import abc
import contextvars
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING
from . import util
from .scope import Scope
if TYPE_CHECKING:
from .processor_interface import TracingProcessor
class Trace(abc.ABC):
"""A complete end-to-end workflow containing related spans and metadata.
A trace represents a logical workflow or operation (e.g., "Customer Service Query"
or "Code Generation") and contains all the spans (individual operations) that occur
during that workflow.
Example:
```python
# Basic trace usage
with trace("Order Processing") as t:
validation_result = await Runner.run(validator, order_data)
if validation_result.approved:
await Runner.run(processor, order_data)
# Trace with metadata and grouping
with trace(
"Customer Service",
group_id="chat_123",
metadata={"customer": "user_456"}
) as t:
result = await Runner.run(support_agent, query)
```
Notes:
- Use descriptive workflow names
- Group related traces with consistent group_ids
- Add relevant metadata for filtering/analysis
- Use context managers for reliable cleanup
- Consider privacy when adding trace data
"""
@abc.abstractmethod
def __enter__(self) -> Trace:
pass
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
@abc.abstractmethod
def start(self, mark_as_current: bool = False) -> None:
"""Start the trace and optionally mark it as the current trace.
Args:
mark_as_current: If true, marks this trace as the current trace
in the execution context.
Notes:
- Must be called before any spans can be added
- Only one trace can be current at a time
- Thread-safe when using mark_as_current
"""
@abc.abstractmethod
def finish(self, reset_current: bool = False) -> None:
"""Finish the trace and optionally reset the current trace.
Args:
reset_current: If true, resets the current trace to the previous
trace in the execution context.
Notes:
- Must be called to complete the trace
- Finalizes all open spans
- Thread-safe when using reset_current
"""
@property
@abc.abstractmethod
def trace_id(self) -> str:
"""Get the unique identifier for this trace.
Returns:
str: The trace's unique ID in the format 'trace_<32_alphanumeric>'
Notes:
- IDs are globally unique
- Used to link spans to their parent trace
- Can be used to look up traces in the dashboard
"""
@property
@abc.abstractmethod
def name(self) -> str:
"""Get the human-readable name of this workflow trace.
Returns:
str: The workflow name (e.g., "Customer Service", "Data Processing")
Notes:
- Should be descriptive and meaningful
- Used for grouping and filtering in the dashboard
- Helps identify the purpose of the trace
"""
@abc.abstractmethod
def export(self) -> dict[str, Any] | None:
"""Export the trace data as a serializable dictionary.
Returns:
dict | None: Dictionary containing trace data, or None if tracing is disabled.
Notes:
- Includes all spans and their data
- Used for sending traces to backends
- May include metadata and group ID
"""
class NoOpTrace(Trace):
"""A no-op implementation of Trace that doesn't record any data.
Used when tracing is disabled but trace operations still need to work.
Maintains proper context management but doesn't store or export any data.
Example:
```python
# When tracing is disabled, traces become NoOpTrace
with trace("Disabled Workflow") as t:
# Operations still work but nothing is recorded
await Runner.run(agent, "query")
```
"""
def __init__(self) -> None:
self._started = False
self._prev_context_token: contextvars.Token[Trace | None] | None = None
def __enter__(self) -> Trace:
if self._started:
return self
self._started = True
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.finish(reset_current=True)
def start(self, mark_as_current: bool = False) -> None:
if mark_as_current:
self._prev_context_token = Scope.set_current_trace(self)
def finish(self, reset_current: bool = False) -> None:
if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
self._prev_context_token = None
@property
def trace_id(self) -> str:
"""The trace's unique identifier.
Returns:
str: A unique ID for this trace.
"""
return "no-op"
@property
def name(self) -> str:
"""The workflow name for this trace.
Returns:
str: Human-readable name describing this workflow.
"""
return "no-op"
def export(self) -> dict[str, Any] | None:
"""Export the trace data as a dictionary.
Returns:
dict | None: Trace data in exportable format, or None if no data.
"""
return None
NO_OP_TRACE = NoOpTrace()
class TraceImpl(Trace):
"""
A trace that will be recorded by the tracing library.
"""
__slots__ = (
"_name",
"_trace_id",
"group_id",
"metadata",
"_prev_context_token",
"_processor",
"_started",
)
def __init__(
self,
name: str,
trace_id: str | None,
group_id: str | None,
metadata: dict[str, Any] | None,
processor: TracingProcessor,
):
self._name = name
self._trace_id = trace_id or util.gen_trace_id()
self.group_id = group_id
self.metadata = metadata
self._prev_context_token: contextvars.Token[Trace | None] | None = None
self._processor = processor
self._started = False
@property
def trace_id(self) -> str:
return self._trace_id
@property
def name(self) -> str:
return self._name
def start(self, mark_as_current: bool = False) -> None:
if self._started:
return
self._started = True
self._processor.on_trace_start(self)
if mark_as_current:
self._prev_context_token = Scope.set_current_trace(self)
def finish(self, reset_current: bool = False) -> None:
if not self._started:
return
self._processor.on_trace_end(self)
if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
self._prev_context_token = None
def __enter__(self) -> Trace:
if self._started:
return self
self.start(mark_as_current=True)
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.finish(reset_current=exc_type is not GeneratorExit)
def export(self) -> dict[str, Any] | None:
return {
"object": "trace",
"id": self.trace_id,
"workflow_name": self.name,
"group_id": self.group_id,
"metadata": self.metadata,
}

View File

@@ -0,0 +1,18 @@
import uuid
from datetime import datetime
from datetime import timezone
def time_iso() -> str:
"""Return the current time in ISO 8601 format."""
return datetime.now(timezone.utc).isoformat()
def gen_trace_id() -> str:
"""Generate a new trace ID."""
return f"trace_{uuid.uuid4().hex}"
def gen_span_id() -> str:
"""Generate a new span ID."""
return f"span_{uuid.uuid4().hex[:24]}"

View File

@@ -1,4 +1,3 @@
from onyx.configs.app_configs import LANGFUSE_HOST
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.utils.logger import setup_logger
@@ -13,16 +12,20 @@ def setup_langfuse_if_creds_available() -> None:
return
import nest_asyncio # type: ignore
from langfuse import get_client
from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor
nest_asyncio.apply()
from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor
OpenAIAgentsInstrumentor().instrument()
langfuse = get_client()
try:
if langfuse.auth_check():
logger.notice(f"Langfuse authentication successful (host: {LANGFUSE_HOST})")
else:
logger.warning("Langfuse authentication failed")
except Exception as e:
logger.error(f"Error setting up Langfuse: {e}")
# TODO: this is how the tracing processor will look once we migrate over to new framework
# config = TraceConfig()
# tracer_provider = trace_api.get_tracer_provider()
# tracer = OITracer(
# trace_api.get_tracer(__name__, __version__, tracer_provider),
# config=config,
# )
# set_trace_processors(
# [OpenInferenceTracingProcessor(cast(trace_api.Tracer, tracer))]
# )

View File

@@ -0,0 +1,450 @@
from __future__ import annotations
import logging
from collections import OrderedDict
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from typing import Union
from openinference.instrumentation import safe_json_dumps
from openinference.semconv.trace import MessageAttributes
from openinference.semconv.trace import MessageContentAttributes
from openinference.semconv.trace import OpenInferenceLLMProviderValues
from openinference.semconv.trace import OpenInferenceLLMSystemValues
from openinference.semconv.trace import OpenInferenceMimeTypeValues
from openinference.semconv.trace import OpenInferenceSpanKindValues
from openinference.semconv.trace import SpanAttributes
from openinference.semconv.trace import ToolAttributes
from openinference.semconv.trace import ToolCallAttributes
from opentelemetry.context import attach
from opentelemetry.context import detach
from opentelemetry.trace import set_span_in_context
from opentelemetry.trace import Span as OtelSpan
from opentelemetry.trace import Status
from opentelemetry.trace import StatusCode
from opentelemetry.trace import Tracer
from opentelemetry.util.types import AttributeValue
from onyx.tracing.framework.processor_interface import TracingProcessor
from onyx.tracing.framework.span_data import AgentSpanData
from onyx.tracing.framework.span_data import FunctionSpanData
from onyx.tracing.framework.span_data import GenerationSpanData
from onyx.tracing.framework.span_data import SpanData
from onyx.tracing.framework.spans import Span
from onyx.tracing.framework.traces import Trace
logger = logging.getLogger(__name__)
class OpenInferenceTracingProcessor(TracingProcessor):
_MAX_HANDOFFS_IN_FLIGHT = 1000
def __init__(self, tracer: Tracer) -> None:
self._tracer = tracer
self._root_spans: dict[str, OtelSpan] = {}
self._otel_spans: dict[str, OtelSpan] = {}
self._tokens: dict[str, object] = {}
# This captures in-flight handoffs. Once a handoff completes, the entry is deleted.
# If the handoff does not complete, the entry stays in the dict.
# Use an OrderedDict and _MAX_HANDOFFS_IN_FLIGHT to cap the size of the dict
# in case there are large numbers of orphaned handoffs
self._reverse_handoffs_dict: OrderedDict[str, str] = OrderedDict()
self._first_input: dict[str, Any] = {}
self._last_output: dict[str, Any] = {}
def on_trace_start(self, trace: Trace) -> None:
"""Called when a trace is started.
Args:
trace: The trace that started.
"""
otel_span = self._tracer.start_span(
name=trace.name,
attributes={
OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.AGENT.value,
},
)
self._root_spans[trace.trace_id] = otel_span
def on_trace_end(self, trace: Trace) -> None:
"""Called when a trace is finished.
Args:
trace: The trace that finished.
"""
if root_span := self._root_spans.pop(trace.trace_id, None):
# Get the first input and last output for this specific trace
trace_first_input = self._first_input.pop(trace.trace_id, None)
trace_last_output = self._last_output.pop(trace.trace_id, None)
# Set input/output attributes on the root span
if trace_first_input is not None:
try:
root_span.set_attribute(
INPUT_VALUE, safe_json_dumps(trace_first_input)
)
root_span.set_attribute(INPUT_MIME_TYPE, JSON)
except Exception:
# Fallback to string if JSON serialization fails
root_span.set_attribute(INPUT_VALUE, str(trace_first_input))
if trace_last_output is not None:
try:
root_span.set_attribute(
OUTPUT_VALUE, safe_json_dumps(trace_last_output)
)
root_span.set_attribute(OUTPUT_MIME_TYPE, JSON)
except Exception:
# Fallback to string if JSON serialization fails
root_span.set_attribute(OUTPUT_VALUE, str(trace_last_output))
root_span.set_status(Status(StatusCode.OK))
root_span.end()
else:
# Clean up stored input/output for this trace if root span doesn't exist
self._first_input.pop(trace.trace_id, None)
self._last_output.pop(trace.trace_id, None)
def on_span_start(self, span: Span[Any]) -> None:
"""Called when a span is started.
Args:
span: The span that started.
"""
if not span.started_at:
return
start_time = datetime.fromisoformat(span.started_at)
parent_span = (
self._otel_spans.get(span.parent_id)
if span.parent_id
else self._root_spans.get(span.trace_id)
)
context = set_span_in_context(parent_span) if parent_span else None
span_name = _get_span_name(span)
otel_span = self._tracer.start_span(
name=span_name,
context=context,
start_time=_as_utc_nano(start_time),
attributes={
OPENINFERENCE_SPAN_KIND: _get_span_kind(span.span_data),
LLM_SYSTEM: OpenInferenceLLMSystemValues.OPENAI.value,
},
)
self._otel_spans[span.span_id] = otel_span
self._tokens[span.span_id] = attach(set_span_in_context(otel_span))
def on_span_end(self, span: Span[Any]) -> None:
"""Called when a span is finished. Should not block or raise exceptions.
Args:
span: The span that finished.
"""
if token := self._tokens.pop(span.span_id, None):
detach(token) # type: ignore[arg-type]
if not (otel_span := self._otel_spans.pop(span.span_id, None)):
return
otel_span.update_name(_get_span_name(span))
# flatten_attributes: dict[str, AttributeValue] = dict(_flatten(span.export()))
# otel_span.set_attributes(flatten_attributes)
data = span.span_data
if isinstance(data, GenerationSpanData):
for k, v in _get_attributes_from_generation_span_data(data):
otel_span.set_attribute(k, v)
elif isinstance(data, FunctionSpanData):
for k, v in _get_attributes_from_function_span_data(data):
otel_span.set_attribute(k, v)
elif isinstance(data, AgentSpanData):
otel_span.set_attribute(GRAPH_NODE_ID, data.name)
# Lookup the parent node if exists
key = f"{data.name}:{span.trace_id}"
if parent_node := self._reverse_handoffs_dict.pop(key, None):
otel_span.set_attribute(GRAPH_NODE_PARENT_ID, parent_node)
end_time: Optional[int] = None
if span.ended_at:
try:
end_time = _as_utc_nano(datetime.fromisoformat(span.ended_at))
except ValueError:
pass
otel_span.set_status(status=_get_span_status(span))
otel_span.end(end_time)
# Store first input and last output per trace_id
trace_id = span.trace_id
input_: Optional[Any] = None
output: Optional[Any] = None
if isinstance(data, FunctionSpanData):
input_ = data.input
output = data.output
elif isinstance(data, GenerationSpanData):
input_ = data.input
output = data.output
if trace_id not in self._first_input and input_ is not None:
self._first_input[trace_id] = input_
if output is not None:
self._last_output[trace_id] = output
def force_flush(self) -> None:
"""Forces an immediate flush of all queued spans/traces."""
# TODO
def shutdown(self) -> None:
"""Called when the application stops."""
# TODO
def _as_utc_nano(dt: datetime) -> int:
return int(dt.astimezone(timezone.utc).timestamp() * 1_000_000_000)
def _get_span_name(obj: Span[Any]) -> str:
if hasattr(data := obj.span_data, "name") and isinstance(name := data.name, str):
return name
return obj.span_data.type # type: ignore[no-any-return]
def _get_span_kind(obj: SpanData) -> str:
if isinstance(obj, AgentSpanData):
return OpenInferenceSpanKindValues.AGENT.value
if isinstance(obj, FunctionSpanData):
return OpenInferenceSpanKindValues.TOOL.value
if isinstance(obj, GenerationSpanData):
return OpenInferenceSpanKindValues.LLM.value
return OpenInferenceSpanKindValues.CHAIN.value
def _get_attributes_from_generation_span_data(
obj: GenerationSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
if isinstance(model := obj.model, str):
yield LLM_MODEL_NAME, model
if isinstance(obj.model_config, dict) and (
param := {k: v for k, v in obj.model_config.items() if v is not None}
):
yield LLM_INVOCATION_PARAMETERS, safe_json_dumps(param)
if base_url := param.get("base_url"):
if "api.openai.com" in base_url:
yield LLM_PROVIDER, OpenInferenceLLMProviderValues.OPENAI.value
yield from _get_attributes_from_chat_completions_input(obj.input)
yield from _get_attributes_from_chat_completions_output(obj.output)
yield from _get_attributes_from_chat_completions_usage(obj.usage)
def _get_attributes_from_chat_completions_input(
obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
try:
yield INPUT_VALUE, safe_json_dumps(obj)
yield INPUT_MIME_TYPE, JSON
except Exception:
pass
yield from _get_attributes_from_chat_completions_message_dicts(
obj,
f"{LLM_INPUT_MESSAGES}.",
)
def _get_attributes_from_chat_completions_output(
obj: Optional[Iterable[Mapping[str, Any]]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
try:
yield OUTPUT_VALUE, safe_json_dumps(obj)
yield OUTPUT_MIME_TYPE, JSON
except Exception:
pass
yield from _get_attributes_from_chat_completions_message_dicts(
obj,
f"{LLM_OUTPUT_MESSAGES}.",
)
def _get_attributes_from_chat_completions_message_dicts(
obj: Iterable[Mapping[str, Any]],
prefix: str = "",
msg_idx: int = 0,
tool_call_idx: int = 0,
) -> Iterator[tuple[str, AttributeValue]]:
if not isinstance(obj, Iterable):
return
for msg in obj:
if isinstance(role := msg.get("role"), str):
yield f"{prefix}{msg_idx}.{MESSAGE_ROLE}", role
if content := msg.get("content"):
yield from _get_attributes_from_chat_completions_message_content(
content,
f"{prefix}{msg_idx}.",
)
if isinstance(tool_call_id := msg.get("tool_call_id"), str):
yield f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALL_ID}", tool_call_id
if isinstance(tool_calls := msg.get("tool_calls"), Iterable):
for tc in tool_calls:
yield from _get_attributes_from_chat_completions_tool_call_dict(
tc,
f"{prefix}{msg_idx}.{MESSAGE_TOOL_CALLS}.{tool_call_idx}.",
)
tool_call_idx += 1
msg_idx += 1
def _get_attributes_from_chat_completions_message_content(
obj: Union[str, Iterable[Mapping[str, Any]]],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if isinstance(obj, str):
yield f"{prefix}{MESSAGE_CONTENT}", obj
elif isinstance(obj, Iterable):
for i, item in enumerate(obj):
if not isinstance(item, Mapping):
continue
yield from _get_attributes_from_chat_completions_message_content_item(
item,
f"{prefix}{MESSAGE_CONTENTS}.{i}.",
)
def _get_attributes_from_chat_completions_message_content_item(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if obj.get("type") == "text" and (text := obj.get("text")):
yield f"{prefix}{MESSAGE_CONTENT_TYPE}", "text"
yield f"{prefix}{MESSAGE_CONTENT_TEXT}", text
def _get_attributes_from_chat_completions_tool_call_dict(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
if id_ := obj.get("id"):
yield f"{prefix}{TOOL_CALL_ID}", id_
if function := obj.get("function"):
if name := function.get("name"):
yield f"{prefix}{TOOL_CALL_FUNCTION_NAME}", name
if arguments := function.get("arguments"):
if arguments != "{}":
yield f"{prefix}{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", arguments
def _get_attributes_from_chat_completions_usage(
obj: Optional[Mapping[str, Any]],
) -> Iterator[tuple[str, AttributeValue]]:
if not obj:
return
if input_tokens := obj.get("input_tokens"):
yield LLM_TOKEN_COUNT_PROMPT, input_tokens
if output_tokens := obj.get("output_tokens"):
yield LLM_TOKEN_COUNT_COMPLETION, output_tokens
# Convert dicts, tuples, etc. into one of the primitive types accepted as attribute values: bool, str, bytes, int, float.
def _convert_to_primitive(value: Any) -> Union[bool, str, bytes, int, float]:
if isinstance(value, (bool, str, bytes, int, float)):
return value
if isinstance(value, (list, tuple)):
return safe_json_dumps(value)
if isinstance(value, dict):
return safe_json_dumps(value)
return str(value)
def _get_attributes_from_function_span_data(
obj: FunctionSpanData,
) -> Iterator[tuple[str, AttributeValue]]:
yield TOOL_NAME, obj.name
if obj.input:
yield INPUT_VALUE, obj.input
yield INPUT_MIME_TYPE, JSON
if obj.output is not None:
yield OUTPUT_VALUE, _convert_to_primitive(obj.output)
if (
isinstance(obj.output, str)
and len(obj.output) > 1
and obj.output[0] == "{"
and obj.output[-1] == "}"
):
yield OUTPUT_MIME_TYPE, JSON
def _get_span_status(obj: Span[Any]) -> Status:
if error := getattr(obj, "error", None):
return Status(
status_code=StatusCode.ERROR,
description=f"{error.get('message')}: {error.get('data')}",
)
else:
return Status(StatusCode.OK)
def _flatten(
obj: Mapping[str, Any],
prefix: str = "",
) -> Iterator[tuple[str, AttributeValue]]:
for key, value in obj.items():
if isinstance(value, dict):
yield from _flatten(value, f"{prefix}{key}.")
elif isinstance(value, (str, int, float, bool)):
yield f"{prefix}{key}", value
else:
yield f"{prefix}{key}", str(value)
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
LLM_SYSTEM = SpanAttributes.LLM_SYSTEM
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = (
SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ
)
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = (
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
)
LLM_TOOLS = SpanAttributes.LLM_TOOLS
METADATA = SpanAttributes.METADATA
OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
TOOL_DESCRIPTION = SpanAttributes.TOOL_DESCRIPTION
TOOL_NAME = SpanAttributes.TOOL_NAME
TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS
GRAPH_NODE_ID = SpanAttributes.GRAPH_NODE_ID
GRAPH_NODE_PARENT_ID = SpanAttributes.GRAPH_NODE_PARENT_ID
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_CONTENTS = MessageAttributes.MESSAGE_CONTENTS
MESSAGE_CONTENT_IMAGE = MessageContentAttributes.MESSAGE_CONTENT_IMAGE
MESSAGE_CONTENT_TEXT = MessageContentAttributes.MESSAGE_CONTENT_TEXT
MESSAGE_CONTENT_TYPE = MessageContentAttributes.MESSAGE_CONTENT_TYPE
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = (
MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
)
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS
MESSAGE_TOOL_CALL_ID = MessageAttributes.MESSAGE_TOOL_CALL_ID
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
JSON = OpenInferenceMimeTypeValues.JSON.value
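
Illustrative usage of the _flatten helper above (the input dict is hypothetical, not part of the original file): nested mappings become dotted attribute keys, and non-primitive leaf values are stringified.

# Hypothetical example only:
example_metadata = {"user": {"id": 42, "tags": ["a", "b"]}, "source": "slack"}
flattened = dict(_flatten(example_metadata))
assert flattened == {"user.id": 42, "user.tags": "['a', 'b']", "source": "slack"}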

View File

@@ -160,3 +160,8 @@ def is_valid_email(text: str) -> bool:
def count_punctuation(text: str) -> int:
return sum(1 for char in text if char in string.punctuation)
def remove_markdown_image_references(text: str) -> str:
"""Remove markdown-style image references like ![alt text](url)"""
return re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", text)
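
A quick illustrative check of the regex above (these strings are hypothetical, not taken from the test suite): image references are removed, while ordinary markdown links are left untouched.

# Hypothetical examples only:
assert remove_markdown_image_references("see ![chart](http://example.com/c.png) here") == "see  here"
assert remove_markdown_image_references("[docs](http://example.com)") == "[docs](http://example.com)"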

View File

@@ -50,7 +50,7 @@ nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==2.6.1
openpyxl==3.1.5
openpyxl==3.0.10
passlib==1.7.4
playwright==1.55.0
psutil==5.9.5
@@ -84,7 +84,7 @@ supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
types-openpyxl==3.1.5.20250919
types-openpyxl==3.0.4.7
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.35.0
@@ -108,7 +108,7 @@ exa_py==1.15.4
braintrust[openai-agents]==0.2.6
braintrust-langchain==0.0.4
openai-agents==0.4.2
langfuse==3.7.0
langfuse==3.10.0
nest_asyncio==1.6.0
openinference-instrumentation-openai-agents==1.3.0
opentelemetry-proto==1.38.0

View File

@@ -19,3 +19,14 @@ class RerankerProvider(str, Enum):
class EmbedTextType(str, Enum):
QUERY = "query"
PASSAGE = "passage"
class WebSearchProviderType(str, Enum):
GOOGLE_PSE = "google_pse"
SERPER = "serper"
EXA = "exa"
class WebContentProviderType(str, Enum):
ONYX_WEB_CRAWLER = "onyx_web_crawler"
FIRECRAWL = "firecrawl"
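
Since these are str-based enums, a persisted configuration value round-trips directly to a member; a small hypothetical sketch:

# Hypothetical usage, assuming the provider choice is stored as a plain string:
assert WebSearchProviderType("exa") is WebSearchProviderType.EXA
assert WebContentProviderType.FIRECRAWL.value == "firecrawl"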

View File

@@ -154,7 +154,7 @@ def test_versions_endpoint(client: TestClient) -> None:
assert migration["onyx"] == "airgapped-intfloat-nomic-migration"
assert migration["relational_db"] == "postgres:15.2-alpine"
assert migration["index"] == "vespaengine/vespa:8.277.17"
assert migration["nginx"] == "nginx:1.23.4-alpine"
assert migration["nginx"] == "nginx:1.25.5-alpine"
# Verify versions are different between stable and dev
assert stable["onyx"] != dev["onyx"], "Stable and dev versions should be different"

View File

@@ -2,6 +2,8 @@ import os
from datetime import datetime
from datetime import timezone
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
@@ -518,3 +520,52 @@ class TestHubSpotConnector:
note_url = connector._get_object_url("notes", "44444")
expected_note_url = "https://app.hubspot.com/contacts/12345/objects/0-4/44444"
assert note_url == expected_note_url
def test_ticket_with_none_content(self) -> None:
"""Test that tickets with None content are handled gracefully."""
connector = HubSpotConnector(object_types=["tickets"], batch_size=10)
connector._access_token = "mock_token"
connector._portal_id = "mock_portal_id"
# Create a mock ticket with None content
mock_ticket = MagicMock()
mock_ticket.id = "12345"
mock_ticket.properties = {
"subject": "Test Ticket",
"content": None, # This is the key test case
"hs_ticket_priority": "HIGH",
}
mock_ticket.updated_at = datetime.now(timezone.utc)
# Mock the HubSpot API client
mock_api_client = MagicMock()
# Mock the API calls and associated object methods
with patch(
"onyx.connectors.hubspot.connector.HubSpot"
) as MockHubSpot, patch.object(
connector, "_paginated_results"
) as mock_paginated, patch.object(
connector, "_get_associated_objects", return_value=[]
), patch.object(
connector, "_get_associated_notes", return_value=[]
):
MockHubSpot.return_value = mock_api_client
mock_paginated.return_value = iter([mock_ticket])
# This should not raise a validation error
document_batches = connector._process_tickets()
first_batch = next(document_batches, None)
# Verify the document was created successfully
assert first_batch is not None
assert len(first_batch) == 1
doc = first_batch[0]
assert doc.id == "hubspot_ticket_12345"
assert doc.semantic_identifier == "Test Ticket"
# Verify the first section has an empty string, not None
assert len(doc.sections) > 0
assert doc.sections[0].text == "" # Should be empty string, not None
assert doc.sections[0].link is not None
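
The connector-side change itself is not included in this diff; a minimal sketch of the coercion the assertions above imply (hypothetical helper name, not the real connector code):

def _ticket_section_text(properties: dict[str, Any]) -> str:
    # HubSpot can return None for the "content" property; fall back to an
    # empty string so the document section text is always a str.
    return properties.get("content") or ""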

View File

@@ -1,4 +1,6 @@
import os
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4
@@ -8,9 +10,11 @@ os.environ["MODEL_SERVER_HOST"] = "disabled"
os.environ["MODEL_SERVER_PORT"] = "9000"
from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
from onyx.configs.constants import FederatedConnectorSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
from onyx.db.models import LLMProvider
@@ -573,3 +577,174 @@ class TestSlackBotFederatedSearch:
finally:
self._teardown_common_mocks(patches)
@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_missing_scope_resilience(
mock_web_client: Mock, mock_redis_client: Mock
) -> None:
"""Test that missing scopes are handled gracefully"""
# Setup mock Redis client
mock_redis = MagicMock()
mock_redis.get.return_value = None # Cache miss
mock_redis_client.return_value = mock_redis
# Setup mock Slack client that simulates missing_scope error
mock_client_instance = MagicMock()
mock_web_client.return_value = mock_client_instance
# Track which channel types were attempted
attempted_types: list[str] = []
def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
if types:
attempted_types.append(types)
# First call: all types including mpim -> missing_scope error
if types and "mpim" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "mpim:read",
"provided": "identify,channels:history,channels:read,groups:read,im:read,search:read",
}
raise SlackApiError("missing_scope", error_response)
# Second call: without mpim -> success
mock_response = MagicMock()
mock_response.validate.return_value = None
mock_response.data = {
"channels": [
{
"id": "C1234567890",
"name": "general",
"is_channel": True,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": False,
"is_member": True,
},
{
"id": "D9876543210",
"name": "",
"is_channel": False,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": True,
"is_member": True,
},
],
"response_metadata": {},
}
return mock_response
mock_client_instance.conversations_list.side_effect = mock_conversations_list
# Call the function
result = fetch_and_cache_channel_metadata(
access_token="xoxp-test-token",
team_id="T1234567890",
include_private=True,
)
# Assertions
# Should have attempted twice: once with mpim, once without
assert len(attempted_types) == 2, f"Expected 2 attempts, got {len(attempted_types)}"
assert "mpim" in attempted_types[0], "First attempt should include mpim"
assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"
# Should have successfully returned channels despite missing scope
assert len(result) == 2, f"Expected 2 channels, got {len(result)}"
assert "C1234567890" in result, "Should have public channel"
assert "D9876543210" in result, "Should have DM channel"
# Verify channel metadata structure
assert result["C1234567890"]["name"] == "general"
assert result["C1234567890"]["type"] == "public_channel"
assert result["D9876543210"]["type"] == "im"
@patch("onyx.context.search.federated.slack_search.get_redis_client")
@patch("onyx.context.search.federated.slack_search.WebClient")
def test_multiple_missing_scopes_resilience(
mock_web_client: Mock, mock_redis_client: Mock
) -> None:
"""Test handling multiple missing scopes gracefully"""
# Setup mock Redis client
mock_redis = MagicMock()
mock_redis.get.return_value = None # Cache miss
mock_redis_client.return_value = mock_redis
# Setup mock Slack client
mock_client_instance = MagicMock()
mock_web_client.return_value = mock_client_instance
# Track attempts
attempted_types: list[str] = []
def mock_conversations_list(types: str | None = None, **kwargs: Any) -> MagicMock:
if types:
attempted_types.append(types)
# First: mpim missing
if types and "mpim" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "mpim:read",
"provided": "identify,channels:history,channels:read,groups:read",
}
raise SlackApiError("missing_scope", error_response)
# Second: im missing
if types and "im" in types:
error_response = {
"ok": False,
"error": "missing_scope",
"needed": "im:read",
"provided": "identify,channels:history,channels:read,groups:read",
}
raise SlackApiError("missing_scope", error_response)
# Third: success with only public and private channels
mock_response = MagicMock()
mock_response.validate.return_value = None
mock_response.data = {
"channels": [
{
"id": "C1234567890",
"name": "general",
"is_channel": True,
"is_private": False,
"is_group": False,
"is_mpim": False,
"is_im": False,
"is_member": True,
}
],
"response_metadata": {},
}
return mock_response
mock_client_instance.conversations_list.side_effect = mock_conversations_list
# Call the function
result = fetch_and_cache_channel_metadata(
access_token="xoxp-test-token",
team_id="T1234567890",
include_private=True,
)
# Should gracefully handle multiple missing scopes
assert len(attempted_types) == 3, f"Expected 3 attempts, got {len(attempted_types)}"
assert "mpim" in attempted_types[0], "First attempt should include mpim"
assert "mpim" not in attempted_types[1], "Second attempt should not include mpim"
assert "im" in attempted_types[1], "Second attempt should include im"
assert "im" not in attempted_types[2], "Third attempt should not include im"
# Should still return available channels
assert len(result) == 1, f"Expected 1 channel, got {len(result)}"
assert result["C1234567890"]["name"] == "general"

Some files were not shown because too many files have changed in this diff.