Compare commits

...

7 Commits
oops ... v2.0.2

Author SHA1 Message Date
Wenxi
e3358d439b chore: hotfix/v2.0.2 (#5813)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: Edwin Luo <edwin@parafin.com>
Co-authored-by: Jessica Singh <86633231+jessicasingh7@users.noreply.github.com>
Co-authored-by: Eli Ben-Shoshan <eli+github@benshoshan.com>
Co-authored-by: Eli Ben-Shoshan <ebs@ufl.edu>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MBP.attlocal.net>
2025-10-20 18:18:42 -07:00
Chris Weaver
37f47ae83b feat: add new fields to usage report (#5784) (#5790) 2025-10-19 21:55:15 -07:00
Wenxi
f9b9819608 chore: hotfix/v2.0.0 beta.5 (#5775)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: Edwin Luo <edwin@parafin.com>
Co-authored-by: Jessica Singh <86633231+jessicasingh7@users.noreply.github.com>
Co-authored-by: Eli Ben-Shoshan <eli+github@benshoshan.com>
Co-authored-by: Eli Ben-Shoshan <ebs@ufl.edu>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MBP.attlocal.net>
2025-10-17 19:19:47 -07:00
Wenxi
9f91a93471 chore: hot fix v2.0.0-beta.4 (#5737)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: Edwin Luo <edwin@parafin.com>
Co-authored-by: Jessica Singh <86633231+jessicasingh7@users.noreply.github.com>
2025-10-15 16:53:36 -07:00
Wenxi
ed40cbdd00 chore: hotfix/v2.0.0 beta.3 (#5715)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: Edwin Luo <edwin@parafin.com>
2025-10-14 12:29:51 -07:00
Wenxi
b36910240d chore: Hotfix v2.0.0-beta.2 (#5658)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
2025-10-07 18:30:48 -07:00
Wenxi
488b27ba04 chore: hotfix v2.0.0 beta.1 (#5616)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
2025-10-07 17:08:17 -07:00
791 changed files with 39937 additions and 50414 deletions

View File

@@ -8,9 +8,9 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
build-and-push:
@@ -33,7 +33,16 @@ jobs:
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout code
uses: actions/checkout@v4
@@ -46,7 +55,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -99,6 +109,15 @@ jobs:
# Needed for trivyignore
- name: Checkout
uses: actions/checkout@v4
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Download digests
uses: actions/download-artifact@v4
@@ -119,7 +138,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3
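
For context on the workflow change above: the new "Check if stable release version" step tags an image as `latest` only when `github.ref_name` is a plain semver tag (e.g. `v2.0.2`) that is not a cloud build, while `edge` is reserved for nightly builds. The following is a minimal Python sketch of that decision; the `docker_tags` helper is hypothetical (the real workflows do this in bash plus docker/metadata-action), only the regex and the `cloud`/`nightly-latest` rules come from the diff.

```python
import re

# Mirrors the bash check: ^v[0-9]+\.[0-9]+\.[0-9]+$ and not a "cloud" ref.
SEMVER_RE = re.compile(r"^v[0-9]+\.[0-9]+\.[0-9]+$")


def docker_tags(ref_name: str) -> list[str]:
    """Return the image tags the workflow would publish for a given github.ref_name."""
    tags = [ref_name]
    if SEMVER_RE.match(ref_name) and "cloud" not in ref_name:
        tags.append("latest")  # stable, non-cloud release
    if ref_name.startswith("nightly-latest"):
        tags.append("edge")  # nightly build
    return tags


assert docker_tags("v2.0.2") == ["v2.0.2", "latest"]
assert docker_tags("v2.0.0-beta.4") == ["v2.0.0-beta.4"]  # pre-release: no "latest"
assert docker_tags("v2.0.2-cloud") == ["v2.0.2-cloud"]    # cloud build: no "latest"
assert docker_tags("nightly-latest-2025-10-20") == ["nightly-latest-2025-10-20", "edge"]
```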

View File

@@ -11,8 +11,8 @@ env:
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
@@ -145,6 +145,15 @@ jobs:
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -157,11 +166,16 @@ jobs:
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@v3

View File

@@ -7,7 +7,10 @@ on:
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
DEPLOYMENT: standalone
jobs:
@@ -45,6 +48,15 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
@@ -57,7 +69,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -107,6 +120,15 @@ jobs:
if: needs.precheck.outputs.should-run == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Download digests
uses: actions/download-artifact@v4
with:
@@ -126,7 +148,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -35,3 +35,7 @@ jobs:
- name: Pull, Tag and Push API Server Image
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
- name: Pull, Tag and Push Model Server Image
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}

View File

@@ -25,9 +25,11 @@ jobs:
- name: Add required Helm repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add keda https://kedacore.github.io/charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Build chart dependencies

View File

@@ -1,171 +0,0 @@
# This workflow is intended to be manually triggered via the GitHub Action tab.
# Given a hotfix branch, it will attempt to open a PR to all release branches and
# by default auto merge them
name: Hotfix release branches
on:
workflow_dispatch:
inputs:
hotfix_commit:
description: "Hotfix commit hash"
required: true
hotfix_suffix:
description: "Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})"
required: true
release_branch_pattern:
description: "Release branch pattern (regex)"
required: true
default: "release/.*"
auto_merge:
description: "Automatically merge the hotfix PRs"
required: true
type: choice
default: "true"
options:
- true
- false
jobs:
hotfix_release_branches:
permissions: write-all
# See https://runs-on.com/runners/linux/
# use a lower powered instance since this just does i/o to docker hub
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
# needs RKUO_DEPLOY_KEY for write access to merge PR's
- name: Checkout Repository
uses: actions/checkout@v4
with:
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
fetch-depth: 0
- name: Set up Git user
run: |
git config user.name "Richard Kuo [bot]"
git config user.email "rkuo[bot]@onyx.app"
- name: Fetch All Branches
run: |
git fetch --all --prune
- name: Verify Hotfix Commit Exists
run: |
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
- name: Get Release Branches
id: get_release_branches
run: |
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
if [ -z "$BRANCHES" ]; then
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
exit 1
fi
echo "Found release branches:"
echo "$BRANCHES"
# Join the branches into a single line separated by commas
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
# Set the branches as an output
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
# notes on all the vagaries of wiring up automated PR's
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
- name: Create and Merge Pull Requests to Matching Release Branches
env:
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
run: |
# Get the branches from the previous step
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
# Convert BRANCHES to an array
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
# Loop through each release branch and create and merge a PR
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
echo "Processing $RELEASE_BRANCH..."
# Parse out the release version by removing "release/" from the branch name
RELEASE_VERSION=${RELEASE_BRANCH#release/}
echo "Release version parsed: $RELEASE_VERSION"
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
# Checkout the release branch
echo "Checking out $RELEASE_BRANCH"
git checkout "$RELEASE_BRANCH"
# Create the new hotfix branch
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
else
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
git checkout -b "$HOTFIX_BRANCH"
fi
# Check if the hotfix commit is a merge commit
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
# -m 1 uses the target branch as the base (which is what we want)
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
else
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
fi
# Perform the cherry-pick
echo "Executing: $CHERRY_PICK_CMD"
eval "$CHERRY_PICK_CMD"
if [ $? -ne 0 ]; then
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
git cherry-pick --abort
continue
fi
# Push the hotfix branch to the remote
echo "Pushing $HOTFIX_BRANCH..."
git push origin "$HOTFIX_BRANCH"
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
# Check if PR already exists
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
if [ -n "$EXISTING_PR" ]; then
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
continue
fi
# Create a new PR and capture the output
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
# Extract the URL from the output
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
echo "Pull request created: $PR_URL"
# Extract PR number from URL
PR_NUMBER=$(basename "$PR_URL")
echo "Pull request created: $PR_NUMBER"
if [ "$AUTO_MERGE" == "true" ]; then
echo "Attempting to merge pull request #$PR_NUMBER"
# Attempt to merge the PR
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
if [ $? -eq 0 ]; then
echo "Pull request #$PR_NUMBER merged successfully."
else
# Optionally, handle the error or continue
echo "Failed to merge pull request #$PR_NUMBER."
fi
fi
done

View File

@@ -14,12 +14,12 @@ env:
S3_ENDPOINT_URL: "http://localhost:9004"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -65,35 +65,45 @@ jobs:
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Pre-pull critical images
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
- name: Pre-pull required images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling critical images to avoid timeout ==="
# Get kind cluster name
echo "=== Pre-pulling required images to avoid timeout ==="
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
# Pre-pull images that are likely to be used
echo "Pre-pulling PostgreSQL image..."
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
echo "Pre-pulling Redis image..."
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
echo "Pre-pulling Onyx images..."
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "Failed to pull $image"
fi
done
echo "=== Images loaded into Kind cluster ==="
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
@@ -149,6 +159,7 @@ jobs:
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
@@ -156,8 +167,10 @@ jobs:
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.primary.persistence.enabled=false \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
@@ -173,8 +186,16 @@ jobs:
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
echo "=== Installation completed successfully ==="
CT_EXIT=$?
set -e
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
else
echo "=== Installation completed successfully ==="
fi
kubectl get pods --all-namespaces
- name: Post-install verification
@@ -199,7 +220,7 @@ jobs:
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so

View File

@@ -19,12 +19,14 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -131,6 +133,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -158,6 +161,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -191,6 +195,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests:
needs:
@@ -337,9 +342,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

.github/workflows/pr-jest-tests.yml (new file, 35 lines)
View File

@@ -0,0 +1,35 @@
name: Run Jest Tests
concurrency:
group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
jobs:
jest-tests:
name: Jest Tests
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Run Jest tests
working-directory: ./web
run: npm test -- --ci --coverage --maxWorkers=50%
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@v4
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage
retention-days: 7

View File

@@ -16,12 +16,14 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -128,6 +130,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -155,6 +158,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -188,6 +192,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests-mit:
needs:
@@ -334,9 +339,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -13,18 +13,28 @@ env:
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Cloudflare R2
R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS: ${{ vars.R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS }}
R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Google Cloud Storage
GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
# Gong
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
@@ -54,22 +64,22 @@ env:
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
# IMAP
IMAP_HOST: ${{ secrets.IMAP_HOST }}
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
IMAP_HOST: ${{ vars.IMAP_HOST }}
IMAP_USERNAME: ${{ vars.IMAP_USERNAME }}
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
IMAP_MAILBOXES: ${{ vars.IMAP_MAILBOXES }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_TEST_BASE_ID: ${{ vars.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ vars.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ vars.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_ID: ${{ vars.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
@@ -100,9 +110,12 @@ env:
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
BITBUCKET_EMAIL: ${{ vars.BITBUCKET_EMAIL }}
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
# Fireflies
FIREFLIES_API_KEY: ${{ secrets.FIREFLIES_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
@@ -132,7 +145,24 @@ jobs:
playwright install chromium
playwright install-deps chromium
- name: Run Tests
- name: Detect Connector changes
id: changes
uses: dorny/paths-filter@v3
with:
filters: |
hubspot:
- 'backend/onyx/connectors/hubspot/**'
- 'backend/tests/daily/connectors/hubspot/**'
salesforce:
- 'backend/onyx/connectors/salesforce/**'
- 'backend/tests/daily/connectors/salesforce/**'
github:
- 'backend/onyx/connectors/github/**'
- 'backend/tests/daily/connectors/github/**'
file_processing:
- 'backend/onyx/file_processing/**'
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
@@ -142,7 +172,49 @@ jobs:
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors
backend/tests/daily/connectors \
--ignore backend/tests/daily/connectors/hubspot \
--ignore backend/tests/daily/connectors/salesforce \
--ignore backend/tests/daily/connectors/github
- name: Run HubSpot Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/hubspot
- name: Run Salesforce Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.salesforce == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/salesforce
- name: Run GitHub Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.github == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/github
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'

View File

@@ -15,7 +15,7 @@ env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
AWS_REGION_NAME: ${{ vars.AWS_REGION_NAME }}
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
@@ -23,7 +23,7 @@ env:
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
AZURE_API_URL: ${{ vars.AZURE_API_URL }}
jobs:
model-check:

View File

@@ -15,6 +15,9 @@ jobs:
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
# and not rkuo's personal account. It is fine to leave this key as is!
- name: Checkout code
uses: actions/checkout@v4
with:

View File

@@ -34,8 +34,7 @@ repos:
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
additional_dependencies:
- prettier
language_version: system
- repo: local
hooks:

View File

@@ -23,12 +23,10 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery background",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -42,16 +40,29 @@
}
},
{
"name": "Celery (all)",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery kg_processing",
"Celery monitoring",
"Celery user_file_processing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -199,6 +210,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -221,13 +261,100 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery kg_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.kg_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=kg_processing@%n",
"-Q",
"kg_processing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery kg_processing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user_file_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
},
{
"name": "Celery docfetching",
"type": "debugpy",
@@ -311,58 +438,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user file processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"--pool=threads",
"-Q",
"user_file_processing,user_file_project_sync"
],
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user file processing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,11 +87,39 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
via Celery Beat.
- **Task Prioritization**: High, Medium, Low priority queues
- **Monitoring**: Built-in heartbeat and liveness checking
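
As a rough illustration of the deployment-mode description above: a launcher could pick which Celery worker processes to spawn from `USE_LIGHTWEIGHT_BACKGROUND_WORKER`. This is a minimal sketch, not the actual `dev_run_background_jobs.py` or supervisord configuration; app and queue names follow the VS Code launch configuration earlier in this compare, while the concurrency values and abbreviated queue lists are assumptions.

```python
# Sketch only: choose worker sets based on USE_LIGHTWEIGHT_BACKGROUND_WORKER.
import os
import shlex
import subprocess

lightweight = os.environ.get("USE_LIGHTWEIGHT_BACKGROUND_WORKER", "true").lower() == "true"

if lightweight:
    # Lightweight mode: one consolidated worker consuming the background queues.
    workers = {
        "background": "--pool=threads --concurrency=20 "
        "-Q kg_processing,monitoring,user_file_processing,connector_pruning",  # abbreviated queue list
    }
else:
    # Standard mode: separate specialized workers that can scale independently.
    workers = {
        "heavy": "--pool=threads --concurrency=4 -Q connector_pruning,csv_generation",
        "kg_processing": "--pool=threads --concurrency=2 -Q kg_processing",
        "monitoring": "--pool=threads --concurrency=1 -Q monitoring",
        "user_file_processing": "--pool=threads --concurrency=2 "
        "-Q user_file_processing,user_file_project_sync,user_file_delete",
    }

# Spawn one celery worker process per entry and wait on them.
procs = [
    subprocess.Popen(
        shlex.split(
            f"celery -A onyx.background.celery.versioned_apps.{app} worker "
            f"--loglevel=INFO --hostname={app}@%n {extra}"
        )
    )
    for app, extra in workers.items()
]
for proc in procs:
    proc.wait()
```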

View File

@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
### Contributing Code
@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
and
[Discord](https://discord.gg/TDJ59cGV2X).
[Discord](https://discord.gg/4NA5SbzrWb).
We would love to see you there!
@@ -105,6 +101,11 @@ pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Fix vscode/cursor auto-imports:
```bash
pip install -e .
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
@@ -117,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `onyx/web` run:
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run
```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```
Navigate to `onyx/web` and run:
```bash
npm i
@@ -129,8 +137,6 @@ npm i
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
@@ -150,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.

View File

@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
Note: Clear and Restart External Volumes and Containers will reset your postgres and Vespa (relational-db and index).
Only run this if you are okay with wiping your data.
## Features
- Hot reload is enabled for the web server and API servers

View File

@@ -111,6 +111,8 @@ COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
# Put logo in assets
COPY ./assets /app/assets

View File

@@ -0,0 +1,153 @@
"""add permission sync attempt tables
Revision ID: 03d710ccf29c
Revises: 96a5702df6aa
Create Date: 2025-09-11 13:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d710ccf29c" # Generate a new unique ID
down_revision = "96a5702df6aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the permission sync status enum
permission_sync_status_enum = sa.Enum(
"not_started",
"in_progress",
"success",
"canceled",
"failed",
"completed_with_errors",
name="permissionsyncstatus",
native_enum=False,
)
# Create doc_permission_sync_attempt table
op.create_table(
"doc_permission_sync_attempt",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("status", permission_sync_status_enum, nullable=False),
sa.Column("total_docs_synced", sa.Integer(), nullable=True),
sa.Column("docs_with_permission_errors", sa.Integer(), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for doc_permission_sync_attempt
op.create_index(
"ix_doc_permission_sync_attempt_time_created",
"doc_permission_sync_attempt",
["time_created"],
unique=False,
)
op.create_index(
"ix_permission_sync_attempt_latest_for_cc_pair",
"doc_permission_sync_attempt",
["connector_credential_pair_id", "time_created"],
unique=False,
)
op.create_index(
"ix_permission_sync_attempt_status_time",
"doc_permission_sync_attempt",
["status", sa.text("time_finished DESC")],
unique=False,
)
# Create external_group_permission_sync_attempt table
# connector_credential_pair_id is nullable - group syncs can be global (e.g., Confluence)
op.create_table(
"external_group_permission_sync_attempt",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
sa.Column("status", permission_sync_status_enum, nullable=False),
sa.Column("total_users_processed", sa.Integer(), nullable=True),
sa.Column("total_groups_processed", sa.Integer(), nullable=True),
sa.Column("total_group_memberships_synced", sa.Integer(), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for external_group_permission_sync_attempt
op.create_index(
"ix_external_group_permission_sync_attempt_time_created",
"external_group_permission_sync_attempt",
["time_created"],
unique=False,
)
op.create_index(
"ix_group_sync_attempt_cc_pair_time",
"external_group_permission_sync_attempt",
["connector_credential_pair_id", "time_created"],
unique=False,
)
op.create_index(
"ix_group_sync_attempt_status_time",
"external_group_permission_sync_attempt",
["status", sa.text("time_finished DESC")],
unique=False,
)
def downgrade() -> None:
# Drop indexes
op.drop_index(
"ix_group_sync_attempt_status_time",
table_name="external_group_permission_sync_attempt",
)
op.drop_index(
"ix_group_sync_attempt_cc_pair_time",
table_name="external_group_permission_sync_attempt",
)
op.drop_index(
"ix_external_group_permission_sync_attempt_time_created",
table_name="external_group_permission_sync_attempt",
)
op.drop_index(
"ix_permission_sync_attempt_status_time",
table_name="doc_permission_sync_attempt",
)
op.drop_index(
"ix_permission_sync_attempt_latest_for_cc_pair",
table_name="doc_permission_sync_attempt",
)
op.drop_index(
"ix_doc_permission_sync_attempt_time_created",
table_name="doc_permission_sync_attempt",
)
# Drop tables
op.drop_table("external_group_permission_sync_attempt")
op.drop_table("doc_permission_sync_attempt")

View File

@@ -0,0 +1,28 @@
"""reset userfile document_id_migrated field
Revision ID: 40926a4dab77
Revises: 64bd5677aeb6
Create Date: 2025-10-06 16:10:32.898668
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "40926a4dab77"
down_revision = "64bd5677aeb6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set all existing records to not migrated
op.execute(
"UPDATE user_file SET document_id_migrated = FALSE "
"WHERE document_id_migrated IS DISTINCT FROM FALSE;"
)
def downgrade() -> None:
# No-op
pass

View File

@@ -0,0 +1,37 @@
"""Add image input support to model config
Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"model_configuration",
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
)
# Seems to be left over from when model visibility was introduced and a nullable field.
# Set any null is_visible values to False
connection = op.get_bind()
connection.execute(
sa.text(
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
)
)
def downgrade() -> None:
op.drop_column("model_configuration", "supports_image_input")

View File

@@ -0,0 +1,37 @@
"""add queries and is web fetch to iteration answer
Revision ID: 6f4f86aef280
Revises: 03d710ccf29c
Create Date: 2025-10-14 18:08:30.920123
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "6f4f86aef280"
down_revision = "03d710ccf29c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add is_web_fetch column
op.add_column(
"research_agent_iteration_sub_step",
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
)
# Add queries column
op.add_column(
"research_agent_iteration_sub_step",
sa.Column("queries", postgresql.JSONB(), nullable=True),
)
def downgrade() -> None:
op.drop_column("research_agent_iteration_sub_step", "queries")
op.drop_column("research_agent_iteration_sub_step", "is_web_fetch")

View File

@@ -0,0 +1,45 @@
"""mcp_tool_enabled
Revision ID: 96a5702df6aa
Revises: 40926a4dab77
Create Date: 2025-10-09 12:10:21.733097
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "96a5702df6aa"
down_revision = "40926a4dab77"
branch_labels = None
depends_on = None
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
def upgrade() -> None:
op.add_column(
"tool",
sa.Column(
"enabled",
sa.Boolean(),
nullable=False,
server_default=sa.true(),
),
)
op.create_index(
"ix_tool_mcp_server_enabled",
"tool",
["mcp_server_id", "enabled"],
)
# Remove the server default so the application controls defaulting
op.alter_column("tool", "enabled", server_default=None)
def downgrade() -> None:
op.execute(DELETE_DISABLED_TOOLS_SQL)
op.drop_index("ix_tool_mcp_server_enabled", table_name="tool")
op.drop_column("tool", "enabled")

View File

@@ -0,0 +1,72 @@
"""personalization_user_info
Revision ID: c8a93a2af083
Revises: 6f4f86aef280
Create Date: 2025-10-14 15:59:03.577343
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c8a93a2af083"
down_revision = "6f4f86aef280"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("personal_name", sa.String(), nullable=True),
)
op.add_column(
"user",
sa.Column("personal_role", sa.String(), nullable=True),
)
op.add_column(
"user",
sa.Column(
"use_memories",
sa.Boolean(),
nullable=False,
server_default=sa.true(),
),
)
op.create_table(
"memory",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("memory_text", sa.Text(), nullable=False),
sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("message_id", sa.Integer(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_memory_user_id", "memory", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_memory_user_id", table_name="memory")
op.drop_table("memory")
op.drop_column("user", "use_memories")
op.drop_column("user", "personal_role")
op.drop_column("user", "personal_name")

View File

@@ -1,29 +1,17 @@
from datetime import datetime
from functools import lru_cache
import jwt
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ee.onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
from ee.onyx.configs.app_configs import SUPER_USERS
from ee.onyx.db.saml import get_saml_account
from ee.onyx.server.seeding import get_seed_config
from ee.onyx.utils.secrets import extract_hashed_cookie
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import AuthType
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -31,75 +19,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
@lru_cache()
def get_public_key() -> str | None:
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
return response.text
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
try:
public_key_pem = get_public_key()
if public_key_pem is None:
logger.error("Failed to retrieve public key")
return None
payload = jwt_decode(
token,
public_key_pem,
algorithms=["RS256"],
audience=None,
)
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
except InvalidTokenError:
logger.error("Invalid JWT token")
get_public_key.cache_clear()
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
get_public_key.cache_clear()
return None
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
async def optional_user_(
request: Request,
user: User | None,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None
# If user is still None, check for JWT in Authorization header
if user is None and JWT_PUBLIC_KEY_URL is not None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
user = await verify_jwt_token(token, async_db_session)
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails:

View File

@@ -0,0 +1,12 @@
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -1,123 +1,4 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.heavy import celery_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def export_query_history_task(
self: Task,
*,
start: datetime,
end: datetime,
start_time: datetime,
# Need to include the tenant_id since the TenantAwareTask needs this
tenant_id: str,
) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
task_id = self.request.id
stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()
with get_session_with_current_tenant() as db_session:
try:
mark_task_as_started_with_id(
db_session=db_session,
task_id=task_id,
)
snapshot_generator = fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
)
for snapshot in snapshot_generator:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
writer.writerows(
qa_pair.to_json()
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
snapshot
)
)
except Exception:
logger.exception(f"Failed to export query history with {task_id=}")
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
report_name = construct_query_history_report_name(task_id)
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
file_metadata={
"start": start.isoformat(),
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
file_id=report_name,
)
delete_task_with_id(
db_session=db_session,
task_id=task_id,
)
except Exception:
logger.exception(
f"Failed to save query history export file; {report_name=}"
)
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
celery_app.autodiscover_tasks(
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -56,6 +56,12 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt
from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt
from onyx.db.permission_sync_attempt import mark_doc_permission_sync_attempt_failed
from onyx.db.permission_sync_attempt import (
mark_doc_permission_sync_attempt_in_progress,
)
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
@@ -113,6 +119,14 @@ def _get_fence_validation_block_expiration() -> int:
"""Jobs / utils for kicking off doc permissions sync tasks."""
def _fail_doc_permission_sync_attempt(attempt_id: int, error_msg: str) -> None:
"""Helper to mark a doc permission sync attempt as failed with an error message."""
with get_session_with_current_tenant() as db_session:
mark_doc_permission_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
@@ -379,6 +393,15 @@ def connector_permission_sync_generator_task(
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
with get_session_with_current_tenant() as db_session:
attempt_id = create_doc_permission_sync_attempt(
connector_credential_pair_id=cc_pair_id,
db_session=db_session,
)
task_logger.info(
f"Created doc permission sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
@@ -389,22 +412,28 @@ def connector_permission_sync_generator_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
error_msg = (
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
if not redis_connector.permissions.fenced: # The fence must exist
raise ValueError(
error_msg = (
f"connector_permission_sync_generator_task - fence not found: "
f"fence={redis_connector.permissions.fence_key}"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
payload = redis_connector.permissions.payload # The payload must exist
if not payload:
raise ValueError(
error_msg = (
"connector_permission_sync_generator_task: payload invalid or not found"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
if payload.celery_task_id is None:
logger.info(
@@ -432,9 +461,11 @@ def connector_permission_sync_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
error_msg = (
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
task_logger.warning(error_msg)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
try:
@@ -470,11 +501,15 @@ def connector_permission_sync_generator_task(
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
logger.error(f"No sync config found for {source_type}")
error_msg = f"No sync config found for {source_type}"
logger.error(error_msg)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
if sync_config.doc_sync_config is None:
if sync_config.censoring_config:
error_msg = f"Doc sync config is None but censoring config exists for {source_type}"
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
raise ValueError(
@@ -483,6 +518,8 @@ def connector_permission_sync_generator_task(
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
mark_doc_permission_sync_attempt_in_progress(attempt_id, db_session)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
@@ -533,8 +570,9 @@ def connector_permission_sync_generator_task(
)
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.update_db(
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
@@ -542,11 +580,23 @@ def connector_permission_sync_generator_task(
credential_id=cc_pair.credential.id,
task_logger=task_logger,
)
tasks_generated += 1
tasks_generated += result.num_updated
docs_with_errors += result.num_errors
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated} docs_with_errors={docs_with_errors}"
)
complete_doc_permission_sync_attempt(
db_session=db_session,
attempt_id=attempt_id,
total_docs_synced=tasks_generated,
docs_with_permission_errors=docs_with_errors,
)
task_logger.info(
f"Completed doc permission sync attempt {attempt_id}: "
f"{tasks_generated} docs, {docs_with_errors} errors"
)
redis_connector.permissions.generator_complete = tasks_generated
@@ -561,6 +611,11 @@ def connector_permission_sync_generator_task(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
with get_session_with_current_tenant() as db_session:
mark_doc_permission_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)

View File

@@ -49,6 +49,16 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.permission_sync_attempt import complete_external_group_sync_attempt
from onyx.db.permission_sync_attempt import (
create_external_group_sync_attempt,
)
from onyx.db.permission_sync_attempt import (
mark_external_group_sync_attempt_failed,
)
from onyx.db.permission_sync_attempt import (
mark_external_group_sync_attempt_in_progress,
)
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
@@ -70,6 +80,14 @@ logger = setup_logger()
_EXTERNAL_GROUP_BATCH_SIZE = 100
def _fail_external_group_sync_attempt(attempt_id: int, error_msg: str) -> None:
"""Helper to mark an external group sync attempt as failed with an error message."""
with get_session_with_current_tenant() as db_session:
mark_external_group_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
@@ -449,6 +467,16 @@ def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
attempt_id = create_external_group_sync_attempt(
connector_credential_pair_id=cc_pair_id,
db_session=db_session,
)
logger.info(
f"Created external group sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
)
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
@@ -463,11 +491,13 @@ def _perform_external_group_sync(
if sync_config is None:
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
_fail_external_group_sync_attempt(attempt_id, msg)
raise ValueError(msg)
if sync_config.group_sync_config is None:
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
_fail_external_group_sync_attempt(attempt_id, msg)
raise ValueError(msg)
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
@@ -477,14 +507,27 @@ def _perform_external_group_sync(
)
mark_old_external_groups_as_stale(db_session, cc_pair_id)
# Mark attempt as in progress
mark_external_group_sync_attempt_in_progress(attempt_id, db_session)
logger.info(f"Marked external group sync attempt {attempt_id} as in progress")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_group_batch: list[ExternalUserGroup] = []
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
external_user_group_batch.append(external_user_group)
# Track progress
total_groups_processed += 1
total_group_memberships_synced += len(external_user_group.user_emails)
seen_users = seen_users.union(external_user_group.user_emails)
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
logger.debug(
f"New external user groups: {external_user_group_batch}"
@@ -506,6 +549,13 @@ def _perform_external_group_sync(
source=cc_pair.connector.source,
)
except Exception as e:
format_error_for_logging(e)
# Mark as failed (this also records any partial progress made before the failure)
mark_external_group_sync_attempt_failed(
attempt_id, db_session, error_message=str(e)
)
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
@@ -517,6 +567,24 @@ def _perform_external_group_sync(
)
remove_stale_external_groups(db_session, cc_pair_id)
# Calculate total unique users processed
total_users_processed = len(seen_users)
# Complete the sync attempt with final progress
complete_external_group_sync_attempt(
db_session=db_session,
attempt_id=attempt_id,
total_users_processed=total_users_processed,
total_groups_processed=total_groups_processed,
total_group_memberships_synced=total_group_memberships_synced,
errors_encountered=0,
)
logger.info(
f"Completed external group sync attempt {attempt_id}: "
f"{total_groups_processed} groups, {total_users_processed} users, "
f"{total_group_memberships_synced} memberships"
)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)

View File

@@ -0,0 +1,119 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def export_query_history_task(
self: Task,
*,
start: datetime,
end: datetime,
start_time: datetime,
# Need to include the tenant_id since the TenantAwareTask needs this
tenant_id: str,
) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
task_id = self.request.id
stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()
with get_session_with_current_tenant() as db_session:
try:
mark_task_as_started_with_id(
db_session=db_session,
task_id=task_id,
)
snapshot_generator = fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
)
for snapshot in snapshot_generator:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
writer.writerows(
qa_pair.to_json()
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
snapshot
)
)
except Exception:
logger.exception(f"Failed to export query history with {task_id=}")
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
report_name = construct_query_history_report_name(task_id)
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
file_metadata={
"start": start.isoformat(),
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
file_id=report_name,
)
delete_task_with_id(
db_session=db_session,
task_id=task_id,
)
except Exception:
logger.exception(
f"Failed to save query history export file; {report_name=}"
)
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise

View File

@@ -1,26 +1,6 @@
import json
import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
#####
# Auto Permission Sync

View File

@@ -73,6 +73,12 @@ def fetch_per_user_query_analytics(
ChatSession.user_id,
)
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
# Include chats that have no explicit feedback instead of dropping them
.join(
ChatMessageFeedback,
ChatMessageFeedback.chat_message_id == ChatMessage.id,
isouter=True,
)
.where(
ChatMessage.time_sent >= start,
)

View File

@@ -50,6 +50,25 @@ def get_empty_chat_messages_entries__paginated(
if message.message_type != MessageType.USER:
continue
# Get user email
user_email = chat_session.user.email if chat_session.user else None
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(
ChatMessageSkeleton(
message_id=message.id,
@@ -57,6 +76,9 @@ def get_empty_chat_messages_entries__paginated(
user_id=str(chat_session.user_id) if chat_session.user_id else None,
flow_type=flow_type,
time_sent=message.time_sent,
assistant_name=assistant_name,
user_email=user_email,
number_of_tokens=message.token_count,
)
)
if len(chat_sessions) == 0:

View File

@@ -124,9 +124,9 @@ def get_space_permission(
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"No permissions found for space '{space_key}'. This is very unlikely "
"to be correct and is more likely caused by an access token with "
"insufficient permissions. Make sure that the access token has Admin "
f"permissions for space '{space_key}'"
)

View File

@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
else 0.0
)
return gmail_connector.retrieve_all_slim_documents(
return gmail_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
else 0.0
)
return google_drive_connector.retrieve_all_slim_documents(
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warn(
logger.warning(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warn(
logger.warning(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue

View File

@@ -105,7 +105,9 @@ def _get_slack_document_access(
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:

View File

@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def generic_doc_sync(
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
slim_connector: SlimConnectorWithPermSync,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -40,7 +40,7 @@ def generic_doc_sync(
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -0,0 +1,15 @@
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
from onyx.feature_flags.interface import FeatureFlagProvider
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
"""
Get the PostHog feature flag provider instance.
This is the EE implementation that gets loaded by the versioned
implementation loader.
Returns:
PostHogFeatureFlagProvider: The PostHog-based feature flag provider
"""
return PostHogFeatureFlagProvider()

View File

@@ -0,0 +1,54 @@
from typing import Any
from uuid import UUID
from ee.onyx.utils.posthog_client import posthog
from onyx.feature_flags.interface import FeatureFlagProvider
from onyx.utils.logger import setup_logger
logger = setup_logger()
class PostHogFeatureFlagProvider(FeatureFlagProvider):
"""
PostHog-based feature flag provider.
Uses PostHog's feature flag API to determine if features are enabled
for specific users. Only active in multi-tenant mode.
"""
def feature_enabled(
self,
flag_key: str,
user_id: UUID,
user_properties: dict[str, Any] | None = None,
) -> bool:
"""
Check if a feature flag is enabled for a user via PostHog.
Args:
flag_key: The identifier for the feature flag to check
user_id: The unique identifier for the user
user_properties: Optional dictionary of user properties/attributes
that may influence flag evaluation
Returns:
True if the feature is enabled for the user, False otherwise.
"""
try:
posthog.set(
distinct_id=user_id,
properties=user_properties,
)
is_enabled = posthog.feature_enabled(
flag_key,
str(user_id),
person_properties=user_properties,
)
return bool(is_enabled) if is_enabled is not None else False
except Exception as e:
logger.error(
f"Error checking feature flag {flag_key} for user {user_id}: {e}"
)
return False
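A minimal usage sketch for the provider above (illustrative only; it assumes PostHog is configured via POSTHOG_API_KEY/POSTHOG_HOST, and the flag key and user properties are made up, not taken from this diff):
from uuid import uuid4
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
provider = PostHogFeatureFlagProvider()
enabled = provider.feature_enabled(
    flag_key="example-feature",             # hypothetical flag name
    user_id=uuid4(),                        # the user's UUID
    user_properties={"tenant": "example"},  # optional targeting attributes
)
# feature_enabled returns False on any PostHog error, so callers can branch safely.
if enabled:
    ...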

View File

@@ -3,11 +3,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
@@ -31,7 +27,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
)
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -117,49 +112,6 @@ def get_application() -> FastAPI:
prefix="/auth",
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
)
# need basic auth router for `logout` endpoint
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
)
elif AUTH_TYPE == AuthType.SAML:
include_auth_router_with_prefix(
application,
saml_router,
)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)
# Analytics endpoints

View File

@@ -1,45 +0,0 @@
import json
import os
from typing import cast
from typing import List
from cohere import Client
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
Embedding = List[float]
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
base_path = os.path.join(os.getcwd(), "onyx", "seeding")
if cohere_enabled and COHERE_DEFAULT_API_KEY:
initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json")
processed_docs = json.load(open(initial_docs_path))
cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
embed_model = "embed-english-v3.0"
for doc in processed_docs:
title_embed_response = cohere_client.embed(
texts=[doc["title"]],
model=embed_model,
input_type="search_document",
)
content_embed_response = cohere_client.embed(
texts=[doc["content"]],
model=embed_model,
input_type="search_document",
)
doc["title_embedding"] = cast(
List[Embedding], title_embed_response.embeddings
)[0]
doc["content_embedding"] = cast(
List[Embedding], content_embed_response.embeddings
)[0]
else:
initial_docs_path = os.path.join(base_path, "initial_docs.json")
processed_docs = json.load(open(initial_docs_path))
return processed_docs

View File

@@ -10,14 +10,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
("/enterprise-settings/logo", {"GET"}),
("/enterprise-settings/logotype", {"GET"}),
("/enterprise-settings/custom-analytics-script", {"GET"}),
# oidc
("/auth/oidc/authorize", {"GET"}),
("/auth/oidc/callback", {"GET"}),
# saml
("/auth/saml/authorize", {"GET"}),
("/auth/saml/callback", {"POST"}),
("/auth/saml/callback", {"GET"}),
("/auth/saml/logout", {"POST"}),
]

View File

@@ -48,7 +48,17 @@ def generate_chat_messages_report(
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
) as temp_file:
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["session_id", "user_id", "flow_type", "time_sent"])
csvwriter.writerow(
[
"session_id",
"user_id",
"flow_type",
"time_sent",
"assistant_name",
"user_email",
"number_of_tokens",
]
)
for chat_message_skeleton_batch in get_all_empty_chat_message_entries(
db_session, period
):
@@ -59,6 +69,9 @@ def generate_chat_messages_report(
chat_message_skeleton.user_id,
chat_message_skeleton.flow_type,
chat_message_skeleton.time_sent.isoformat(),
chat_message_skeleton.assistant_name,
chat_message_skeleton.user_email,
chat_message_skeleton.number_of_tokens,
]
)

View File

@@ -16,6 +16,9 @@ class ChatMessageSkeleton(BaseModel):
user_id: str | None
flow_type: FlowType
time_sent: datetime
assistant_name: str | None
user_email: str | None
number_of_tokens: int
class UserSkeleton(BaseModel):

View File

@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import get_anthropic_model_names
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
max_input_tokens=None,
)
for name in ANTHROPIC_MODEL_NAMES
for name in get_anthropic_model_names()
],
api_key_changed=True,
)

View File

@@ -0,0 +1,22 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)

View File

@@ -1,27 +1,9 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from ee.onyx.utils.posthog_client import posthog
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:

View File

@@ -1,14 +1,12 @@
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # specific to the model version!
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
global _CONNECTOR_CLASSIFIER_TOKENIZER
from transformers import AutoTokenizer, PreTrainedTokenizer
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
from transformers import AutoTokenizer, PreTrainedTokenizer
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
@@ -141,7 +148,9 @@ def get_local_intent_model(
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
) -> "SetFitModel":
from setfit import SetFitModel
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
@@ -179,7 +188,7 @@ def get_local_information_content_model(
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
tokenizer: "PreTrainedTokenizer",
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
device = intent_model.device
@@ -401,7 +410,7 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore

View File

@@ -2,13 +2,11 @@ import asyncio
import time
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from litellm.exceptions import RateLimitError
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
@@ -20,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer
logger = setup_logger()
router = APIRouter(prefix="/encoder")
@@ -88,8 +89,10 @@ def get_embedding_model(
def get_local_reranking_model(
model_name: str,
) -> CrossEncoder:
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder # type: ignore
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
model = CrossEncoder(model_name)
@@ -207,6 +210,8 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
from litellm.exceptions import RateLimitError
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(

View File

@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SKIP_WARM_UP
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
if not SKIP_WARM_UP:
if not INDEXING_ONLY:
logger.notice("Warming up intent model for inference model server")
warm_up_intent_model()
else:
logger.notice(
"Warming up content information model for indexing model server"
)
warm_up_information_content_model()
else:
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
yield

View File

@@ -1,16 +1,20 @@
import json
import os
from typing import cast
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertModel # type: ignore
from transformers import DistilBertTokenizer # type: ignore
if TYPE_CHECKING:
from transformers import DistilBertConfig # type: ignore
class HybridClassifier(nn.Module):
def __init__(self) -> None:
from transformers import DistilBertConfig, DistilBertModel
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
class ConnectorClassifier(nn.Module):
def __init__(self, config: DistilBertConfig) -> None:
def __init__(self, config: "DistilBertConfig") -> None:
from transformers import DistilBertTokenizer, DistilBertModel
super().__init__()
self.config = config
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
from transformers import DistilBertConfig
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),

View File

@@ -100,9 +100,14 @@ class IterationAnswer(BaseModel):
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible across tools
queries: list[str] | None = None
class AggregatedDRContext(BaseModel):

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Any
from typing import cast
from braintrust import traced
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.process_llm_stream import (
BasicSearchProcessedStreamResults,
)
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
@@ -37,6 +41,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import build_citation_map_from_numbers
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.memories import make_memories_callback
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
@@ -70,6 +75,8 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.prompts.prompt_utils import handle_memories
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -116,7 +123,9 @@ def _get_available_tools(
else:
include_kg = False
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
tool_dict: dict[int, Tool] = {
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
}
for tool in graph_config.tooling.tools:
@@ -484,6 +493,16 @@ def clarifier(
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
user = (
graph_config.tooling.search_tool.user
if graph_config.tooling.search_tool
else None
)
memories_callback = make_memories_callback(user, db_session)
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
assistant_system_prompt = handle_memories(
assistant_system_prompt, memories_callback
)
chat_history_string = (
get_chat_history_string(
@@ -666,28 +685,30 @@ def clarifier(
system_prompt_to_use = assistant_system_prompt
user_prompt_to_use = decision_prompt + assistant_task_prompt
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
@traced(name="clarifier stream and process", type="llm")
def stream_and_process() -> BasicSearchProcessedStreamResults:
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
return process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
full_response = stream_and_process()
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):

View File

@@ -199,6 +199,7 @@ def save_iteration(
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)

View File

@@ -180,6 +180,7 @@ def save_iteration(
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)

View File

@@ -1,3 +1,4 @@
import json
from datetime import datetime
from typing import cast
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,6 +64,29 @@ def image_generation(
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
@@ -69,7 +94,15 @@ def image_generation(
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
for tool_response in image_tool.run(prompt=branch_query):
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
@@ -95,6 +128,7 @@ def image_generation(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
@@ -107,15 +141,29 @@ def image_generation(
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {branch_query}"
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
return BranchUpdate(

View File

@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# Needed for PydanticType

View File

@@ -2,30 +2,28 @@ from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# TODO Dependency inject for testing
class ExaClient(InternetSearchProvider):
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="fast",
livecrawl="never",
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetSearchResult(
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=urls,
text=True,
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetContent(
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",

View File

@@ -0,0 +1,147 @@
import json
from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
payload = {
"q": query,
}
response = requests.post(
SERPER_SEARCH_URL,
headers=self.headers,
data=json.dumps(payload),
)
response.raise_for_status()
results = response.json()
organic_results = results["organic"]
return [
WebSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
author=None,
published_date=None,
)
for result in organic_results
]
def contents(self, urls: list[str]) -> list[WebContent]:
if not urls:
return []
# Serper can respond with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccessful scrape.
def safe_get_webpage_content(url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
return list(e.map(safe_get_webpage_content, urls))
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
}
response = requests.post(
SERPER_CONTENTS_URL,
headers=self.headers,
data=json.dumps(payload),
)
# 400 returned when serper cannot scrape
if response.status_code == 400:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
response.raise_for_status()
response_json = response.json()
# Response only guarantees text
text = response_json["text"]
# metadata & jsonld are not guaranteed to be present
metadata = response_json.get("metadata", {})
jsonld = response_json.get("jsonld", {})
title = extract_title_from_metadata(metadata)
# Serper does not provide a reliable mechanism to extract the url
response_url = url
published_date_str = extract_published_date_from_jsonld(jsonld)
published_date = None
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
return WebContent(
title=title or "",
link=response_url,
full_content=text or "",
published_date=published_date,
)
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None
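A hedged usage sketch, not from the PR, assuming the SerperClient defined above and a configured SERPER_API_KEY: it exercises both entry points and shows how failed scrapes surface as scrape_successful=False instead of raising, per the safe_get_webpage_content wrapper.
def demo_serper_search(query: str) -> None:
    client = SerperClient()  # falls back to SERPER_API_KEY from config
    results = client.search(query)                            # list[WebSearchResult]
    pages = client.contents([r.link for r in results[:3]])    # list[WebContent]
    for page in pages:
        status = "ok" if page.scrape_successful else "scrape failed"
        print(f"{page.link}: {status}, {len(page.full_content)} chars")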

View File

@@ -7,7 +7,7 @@ from langsmith import traceable
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
@@ -75,15 +75,15 @@ def web_search(
raise ValueError("No internet search provider found")
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[InternetSearchResult]:
search_results: list[InternetSearchResult] = []
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = provider.search(search_query)
except Exception as e:
logger.error(f"Error performing search: {e}")
return search_results
search_results: list[InternetSearchResult] = _search(search_query)
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"

View File

@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
@@ -23,7 +23,7 @@ def dedup_urls(
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, InternetSearchResult] = {}
unique_results_by_link: dict[str, WebSearchResult] = {}
for query, result in state.results_to_open:
branch_questions_to_urls[query].append(result.link)
if result.link not in unique_results_by_link:

View File

@@ -13,7 +13,7 @@ class ProviderType(Enum):
EXA = "exa"
class InternetSearchResult(BaseModel):
class WebSearchResult(BaseModel):
title: str
link: str
author: str | None = None
@@ -21,18 +21,19 @@ class InternetSearchResult(BaseModel):
snippet: str | None = None
class InternetContent(BaseModel):
class WebContent(BaseModel):
title: str
link: str
full_content: str
published_date: datetime | None = None
scrape_successful: bool = True
class InternetSearchProvider(ABC):
class WebSearchProvider(ABC):
@abstractmethod
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
pass
@abstractmethod
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
pass
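Because WebSearchProvider is an ABC, a test double only needs to satisfy search() and contents(), which lines up with the "Dependency inject for testing" TODO in the Exa client. A minimal sketch under that assumption; the canned values are illustrative, not from the PR.
class FakeWebSearchProvider(WebSearchProvider):
    """In-memory provider useful for unit tests; returns canned results."""

    def search(self, query: str) -> list[WebSearchResult]:
        return [
            WebSearchResult(
                title=f"Result for {query}",
                link="https://example.com",
                snippet="canned snippet",
                author=None,
                published_date=None,
            )
        ]

    def contents(self, urls: list[str]) -> list[WebContent]:
        # scrape_successful defaults to True, matching real successful fetches.
        return [
            WebContent(title="Example", link=url, full_content="canned body")
            for url in urls
        ]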

View File

@@ -1,13 +1,19 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> InternetSearchProvider | None:
def get_default_provider() -> WebSearchProvider | None:
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:
return SerperClient()
return None
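As written, get_default_provider prefers Exa whenever EXA_API_KEY is set and only falls back to Serper; None means no web search is available. A short illustrative check of that precedence, not in the PR:
provider = get_default_provider()
if provider is None:
    print("No web search provider configured (neither EXA_API_KEY nor SERPER_API_KEY set)")
elif isinstance(provider, ExaClient):
    print("Exa takes precedence when both keys are present")
else:
    print("Falling back to Serper")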

View File

@@ -4,13 +4,13 @@ from typing import Annotated
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
class InternetSearchInput(SubAgentInput):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
parallelization_nr: int = 0
branch_question: Annotated[str, lambda x, y: y] = ""
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
class InternetSearchUpdate(LoggerUpdate):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class FetchInput(SubAgentInput):

View File

@@ -1,8 +1,8 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
def dummy_inference_section_from_internet_content(
result: InternetContent,
result: WebContent,
) -> InferenceSection:
truncated_content = truncate_search_result_content(result.full_content)
return InferenceSection(
@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
boost=1,
recency_bias=1.0,
score=1.0,
hidden=False,
hidden=(not result.scrape_successful),
metadata={},
match_highlights=[],
doc_summary=truncated_content,
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
def dummy_inference_section_from_internet_search_result(
result: InternetSearchResult,
result: WebSearchResult,
) -> InferenceSection:
return InferenceSection(
center_chunk=InferenceChunk(

View File

@@ -1,7 +1,6 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
@@ -88,11 +87,5 @@ class GraphConfig(BaseModel):
# Only needed for agentic search
persistence: GraphPersistence
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True

View File

@@ -1,71 +0,0 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
def call_tool(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
cast(GraphConfig, config["metadata"]["config"])
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
try:
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
tool_final_result = tool_runner.tool_final_result()
except Exception as e:
raise ToolCallException(
f"Error during tool call for {tool.display_name}: {e}"
) from e
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
tool_call_result=build_tool_message(
tool_call, tool_runner.tool_message_content()
),
)
tool_call_output = ToolCallOutput(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)
return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -1,354 +0,0 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
# TODO: Add trimming logic
history_segments = []
for msg in prompt_builder.message_history:
if isinstance(msg, HumanMessage):
role = "User"
elif isinstance(msg, AIMessage):
role = "Assistant"
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
def _expand_query(
query: str,
expansion_type: QueryExpansionType,
prompt_builder: AnswerPromptBuilder,
) -> str:
history_str = _create_history_str(prompt_builder)
if history_str:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query, history=history_str)
else:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query)
msg = HumanMessage(content=expansion_prompt)
primary_llm, _ = get_default_llms()
response = primary_llm.invoke([msg])
rephrased_query: str = cast(str, response.content)
return rephrased_query
def _expand_query_non_tool_calling_llm(
expanded_keyword_thread: TimeoutThread[str],
expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
keyword_expansion: str | None = wait_on_background(expanded_keyword_thread)
semantic_expansion: str | None = wait_on_background(expanded_semantic_thread)
if keyword_expansion is None or semantic_expansion is None:
return None
return QueryExpansions(
keywords_expansions=[keyword_expansion],
semantic_expansions=[semantic_expansion],
)
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
"""
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
expanded_keyword_thread: TimeoutThread[str] | None = None
expanded_semantic_thread: TimeoutThread[str] | None = None
# If we have override_kwargs, add them to the tool_args
override_kwargs: SearchToolOverrideKwargs = (
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
)
override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
)
):
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.prompt_builder.raw_user_query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.prompt_builder.raw_user_query,
)
if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
expanded_keyword_thread = run_in_background(
_expand_query,
agent_config.inputs.prompt_builder.raw_user_query,
QueryExpansionType.KEYWORD,
prompt_builder,
)
expanded_semantic_thread = run_in_background(
_expand_query,
agent_config.inputs.prompt_builder.raw_user_query,
QueryExpansionType.SEMANTIC,
prompt_builder,
)
structured_response_format = agent_config.inputs.structured_response_format
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
tool_name, tool_args = (
force_use_tool.tool_name,
force_use_tool.args,
)
tool = get_tool_by_name(tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not using_tool_calling_llm and tools:
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=force_use_tool,
tools=tools,
prompt_builder=prompt_builder,
llm=llm,
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
# dual keyword expansion needs to be added here for non-tool calling LLM case
if (
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
and expanded_keyword_thread
and expanded_semantic_thread
and tool.name == SearchTool._NAME
):
override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
expanded_keyword_thread=expanded_keyword_thread,
expanded_semantic_thread=expanded_semantic_thread,
)
if (
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
and tool.name == SearchTool._NAME
and override_kwargs.expanded_queries
):
if (
override_kwargs.expanded_queries.keywords_expansions is None
or override_kwargs.expanded_queries.semantic_expansions is None
):
raise ValueError("No expanded keyword or semantic threads found.")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
# if we're skipping gen ai answer generation, we should only
# continue if we're forcing a tool call (which will be emitted by
# the tool calling llm in the stream() below)
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
return ToolChoiceUpdate(
tool_choice=None,
)
built_prompt = (
prompt_builder.build()
if isinstance(prompt_builder, AnswerPromptBuilder)
else prompt_builder.built_prompt
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=(
[tool.tool_definition() for tool in tools] or None
if using_tool_calling_llm
else None
),
tool_choice=(
"required"
if tools and force_use_tool.force_use and using_tool_calling_llm
else None
),
structured_response_format=structured_response_format,
)
tool_message = process_llm_stream(
stream,
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
ind=0,
).ai_message_chunk
if tool_message is None:
raise ValueError("No tool message emitted by LLM")
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.debug("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
# TODO: here we could handle parallel tool calls. Right now
# we just pick the first one that matches.
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in tool_message.tool_calls:
known_tools_by_name = [
tool for tool in tools if tool.name == tool_call_request["name"]
]
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
break
logger.error(
"Tool call requested with unknown name field. \n"
f"tools: {tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
raise ValueError(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
if (
selected_tool.name == SearchTool._NAME
and expanded_keyword_thread
and expanded_semantic_thread
):
override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
expanded_keyword_thread=expanded_keyword_thread,
expanded_semantic_thread=expanded_semantic_thread,
)
if (
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
and selected_tool.name == SearchTool._NAME
and override_kwargs.expanded_queries
):
# TODO: this is a hack to handle the case where the expanded queries are not found.
# We should refactor this to be more robust.
if (
override_kwargs.expanded_queries.keywords_expansions is None
or override_kwargs.expanded_queries.semantic_expansions is None
):
raise ValueError("No expanded keyword or semantic threads found.")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -1,17 +0,0 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
agent_config = cast(GraphConfig, config["metadata"]["config"])
return ToolChoiceInput(
# NOTE: this node is used at the top level of the agent, so we always stream
should_stream_answer=True,
prompt_snapshot=None, # uses default prompt builder
tools=[tool.name for tool in (agent_config.tooling.tools or [])],
)

View File

@@ -3,6 +3,7 @@ from typing import cast
from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langfuse.langchain import CallbackHandler
from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
@@ -15,12 +16,13 @@ from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import AnswerStream
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.utils.logger import setup_logger
logger = setup_logger()
GraphInput = DCMainInput | KBMainInput | DRMainInput
@@ -30,10 +32,16 @@ def manage_sync_streaming(
graph_input: GraphInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
callbacks: list[CallbackHandler] = []
if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
callbacks.append(CallbackHandler())
for event in compiled_graph.stream(
stream_mode="custom",
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
config={
"metadata": {"config": config, "thread_id": str(message_id)},
"callbacks": callbacks, # type: ignore
},
):
yield cast(CustomStreamEvent, event)
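The callback list is only populated when both Langfuse keys are configured, so tracing stays opt-in. A small sketch of the same gating pattern in isolation; the helper name is mine, not the repo's, and CallbackHandler() is instantiated with no arguments exactly as in the diff.
from langfuse.langchain import CallbackHandler

def build_callbacks(public_key: str | None, secret_key: str | None) -> list[CallbackHandler]:
    # Only attach the Langfuse handler when both credentials are present;
    # otherwise the graph streams with an empty callback list.
    if public_key and secret_key:
        return [CallbackHandler()]
    return []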

View File

@@ -5,11 +5,10 @@ from typing import Literal
from typing import Type
from typing import TypeVar
from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
@@ -29,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
@traced(name="stream llm", type="llm")
def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
@@ -147,6 +147,7 @@ def invoke_llm_json(
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
from litellm.utils import get_supported_openai_params, supports_response_schema
# check if the model supports response_format: json_schema
supports_json = "response_format" in (

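The litellm helpers move from module-level imports to an import inside invoke_llm_json, so litellm is only paid for by callers that actually need JSON-schema detection. A generic sketch of the same deferred-import pattern; the keyword arguments below are assumptions about litellm's signatures, not the repo's call site.
def supports_json_schema(model: str, provider: str) -> bool:
    # Deferred import: litellm is only imported when this helper is called,
    # mirroring the change above. Argument names are assumptions.
    from litellm.utils import get_supported_openai_params, supports_response_schema

    supported = get_supported_openai_params(model=model, custom_llm_provider=provider) or []
    return "response_format" in supported and supports_response_schema(
        model=model, custom_llm_provider=provider
    )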
View File

@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
<tr>
<td class="footer">
© {year} {application_name}. All rights reserved.
{slack_fragment}
{community_link_fragment}
</td>
</tr>
</table>
@@ -161,9 +161,9 @@ def build_html_email(
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
community_link_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
@@ -175,7 +175,7 @@ def build_html_email(
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
community_link_fragment=community_link_fragment,
year=datetime.now().year,
)

backend/onyx/auth/jwt.py
View File

@@ -0,0 +1,177 @@
import json
from enum import Enum
from functools import lru_cache
from typing import Any
from typing import cast
import jwt
import requests
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
class PublicKeyFormat(Enum):
JWKS = "jwks"
PEM = "pem"
@lru_cache()
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
"""Fetch and cache the raw JWT verification material."""
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
try:
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
except requests.RequestException as exc:
logger.error(f"Failed to fetch JWT public key: {str(exc)}")
return None
content_type = response.headers.get("Content-Type", "").lower()
raw_body = response.text
body_lstripped = raw_body.lstrip()
if "application/json" in content_type or body_lstripped.startswith("{"):
try:
data = response.json()
except ValueError:
logger.error("JWT public key URL returned invalid JSON")
return None
if isinstance(data, dict) and "keys" in data:
return data, PublicKeyFormat.JWKS
logger.error(
"JWT public key URL returned JSON but no JWKS 'keys' field was found"
)
return None
body = raw_body.strip()
if not body:
logger.error("JWT public key URL returned an empty response")
return None
return body, PublicKeyFormat.PEM
def get_public_key(token: str) -> RSAPublicKey | str | None:
"""Return the concrete public key used to verify the provided JWT token."""
payload = _fetch_public_key_payload()
if payload is None:
logger.error("Failed to retrieve public key payload")
return None
key_material, key_format = payload
if key_format is PublicKeyFormat.JWKS:
jwks_data = cast(dict[str, Any], key_material)
return _resolve_public_key_from_jwks(token, jwks_data)
return cast(str, key_material)
def _resolve_public_key_from_jwks(
token: str, jwks_payload: dict[str, Any]
) -> RSAPublicKey | None:
try:
header = jwt.get_unverified_header(token)
except PyJWTError as e:
logger.error(f"Unable to parse JWT header: {str(e)}")
return None
keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
if not keys:
logger.error("JWKS payload did not contain any keys")
return None
kid = header.get("kid")
thumbprint = header.get("x5t")
candidates = []
if kid:
candidates = [k for k in keys if k.get("kid") == kid]
if not candidates and thumbprint:
candidates = [k for k in keys if k.get("x5t") == thumbprint]
if not candidates and len(keys) == 1:
candidates = keys
if not candidates:
logger.warning(
"No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
)
return None
if len(candidates) > 1:
logger.warning(
"Multiple JWKs matched token header kid=%s; selecting the first occurrence",
kid,
)
jwk = candidates[0]
try:
return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
except ValueError as e:
logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
return None
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
public_key = get_public_key(token)
if public_key is None:
logger.error("Unable to resolve a public key for JWT verification")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
try:
from sqlalchemy import func
payload = jwt_decode(
token,
public_key,
algorithms=["RS256"],
options={"verify_aud": False},
)
except InvalidTokenError as e:
logger.error(f"Invalid JWT token: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
break
return None
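The notable pattern in this new module is the lru_cache on _fetch_public_key_payload combined with cache_clear() on failure, which buys one retry with fresh key material when the identity provider rotates keys. A stripped-down, self-contained sketch of that approach; names and the placeholder URL are mine, only the decode options come from the diff.
import jwt
import requests
from functools import lru_cache
from typing import Any

PUBLIC_KEY_URL = "https://example.com/jwt-public-key"  # hypothetical stand-in for JWT_PUBLIC_KEY_URL

@lru_cache()
def fetch_public_key() -> str:
    # Cached fetch; the real module parses JWKS vs. PEM, this sketch assumes PEM.
    return requests.get(PUBLIC_KEY_URL).text

def decode_with_rotation_retry(token: str, attempts: int = 2) -> dict[str, Any] | None:
    for attempt in range(attempts):
        try:
            return jwt.decode(
                token,
                fetch_public_key(),
                algorithms=["RS256"],
                options={"verify_aud": False},
            )
        except jwt.PyJWTError:
            if attempt < attempts - 1:
                # Key may have rotated: drop the cached material and retry once.
                fetch_public_key.cache_clear()
                continue
            return None
    return None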

View File

@@ -3,12 +3,14 @@ from typing import Any
from typing import cast
from onyx.auth.schemas import UserRole
from onyx.configs.constants import KV_NO_AUTH_USER_PERSONALIZATION_KEY
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from onyx.configs.constants import NO_AUTH_USER_EMAIL
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPersonalization
from onyx.server.manage.models import UserPreferences
@@ -18,6 +20,12 @@ def set_no_auth_user_preferences(
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
def set_no_auth_user_personalization(
store: KeyValueStore, personalization: UserPersonalization
) -> None:
store.store(KV_NO_AUTH_USER_PERSONALIZATION_KEY, personalization.model_dump())
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
try:
preferences_data = cast(
@@ -33,6 +41,15 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
def fetch_no_auth_user(
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
) -> UserInfo:
personalization = UserPersonalization()
try:
personalization_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PERSONALIZATION_KEY)
)
personalization = UserPersonalization(**personalization_data)
except KvKeyNotFoundError:
pass
return UserInfo(
id=NO_AUTH_USER_ID,
email=NO_AUTH_USER_EMAIL,
@@ -41,6 +58,7 @@ def fetch_no_auth_user(
is_verified=True,
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
personalization=personalization,
is_anonymous_user=anonymous_user_enabled,
password_configured=False,
)
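Personalization follows the same load-with-default convention as the existing preferences helper: a missing key yields an empty UserPersonalization rather than an error. A compact sketch of that pattern as a hypothetical generic helper (not in the PR), reusing the store types imported above.
from collections.abc import Mapping
from typing import Any, TypeVar, cast
from pydantic import BaseModel
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError

M = TypeVar("M", bound=BaseModel)

def load_or_default(store: KeyValueStore, key: str, model: type[M]) -> M:
    # Hypothetical helper: fall back to the model's defaults when the key
    # has never been written to the KV store.
    try:
        return model(**cast(Mapping[str, Any], store.load(key)))
    except KvKeyNotFoundError:
        return model()

# e.g. load_or_default(store, KV_NO_AUTH_USER_PERSONALIZATION_KEY, UserPersonalization)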

View File

@@ -54,6 +54,8 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
@@ -61,6 +63,7 @@ from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.jwt import verify_jwt_token
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
@@ -70,6 +73,7 @@ from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.configs.app_configs import PASSWORD_MAX_LENGTH
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
from onyx.configs.app_configs import PASSWORD_REQUIRE_DIGIT
@@ -103,19 +107,21 @@ from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.saml import get_saml_account
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.secrets import extract_hashed_cookie
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.timing import log_function_time
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -134,10 +140,9 @@ def is_user_admin(user: User | None) -> bool:
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
if AUTH_TYPE == AuthType.CLOUD:
raise ValueError(
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -324,8 +329,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
user_create.role = UserRole.ADMIN
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
user = await super().create(
user_create, safe=safe, request=request
) # type: ignore
user_created = True
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
@@ -351,11 +360,42 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
role=user_create.role,
)
user = await self.update(user_update, user)
if user_created:
await self._assign_default_pinned_assistants(user, db_session)
remove_user_from_invited_users(user_create.email)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def _assign_default_pinned_assistants(
self, user: User, db_session: AsyncSession
) -> None:
if user.pinned_assistants is not None:
return
result = await db_session.execute(
select(Persona.id)
.where(
Persona.is_default_persona.is_(True),
Persona.is_public.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),
)
.order_by(
nulls_last(Persona.display_priority.asc()),
Persona.id.asc(),
)
)
default_persona_ids = list(result.scalars().all())
if not default_persona_ids:
return
await self.user_db.update(
user,
{"pinned_assistants": default_persona_ids},
)
user.pinned_assistants = default_persona_ids
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to configurable security policy (defined via environment variables)
if len(password) < PASSWORD_MIN_LENGTH:
@@ -476,6 +516,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self._assign_default_pinned_assistants(user, db_session)
await self.on_after_register(user, request)
else:
@@ -1018,13 +1059,28 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
async def optional_user_(
async def _check_for_saml_and_jwt(
request: Request,
user: User | None,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None
# If user is still None, check for JWT in Authorization header
if user is None and JWT_PUBLIC_KEY_URL is not None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
user = await verify_jwt_token(token, async_db_session)
return user
@@ -1033,14 +1089,14 @@ async def optional_user(
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"onyx.auth.users", "optional_user_"
)
user = await versioned_fetch_user(request, user, async_db_session)
user = await _check_for_saml_and_jwt(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
try:
hashed_api_key = get_hashed_api_key_from_request(request)
except ValueError:
hashed_api_key = None
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
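The JWT path only engages when JWT_PUBLIC_KEY_URL is configured and no cookie/SAML user was found, and the token is read from a standard Bearer Authorization header. A small standalone sketch of that header handling; the function name is mine, not the repo's.
def extract_bearer_token(authorization_header: str | None) -> str | None:
    # Mirrors the handling above: accept only "Bearer <token>" and strip
    # surrounding whitespace from the token itself.
    if not authorization_header or not authorization_header.startswith("Bearer "):
        return None
    token = authorization_header[len("Bearer "):].strip()
    return token or None

# extract_bearer_token("Bearer abc.def.ghi") -> "abc.def.ghi"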

View File

@@ -0,0 +1,137 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received for consolidated background worker.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
# Initialize Vespa httpx pool (needed for light worker tasks)
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Fewer startup checks in the multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
]
)
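The consolidated worker sizes its SQL and Vespa pools from the worker's configured concurrency plus a small fudge factor inside the worker_init signal. A generic sketch of that signal-driven sizing, with illustrative names; the real handler is on_worker_init above.
from typing import Any
from celery.signals import worker_init

EXTRA_CONCURRENCY = 8  # headroom for connections made outside task threads

@worker_init.connect
def size_pools(sender: Any, **kwargs: Any) -> None:
    # sender is the Worker instance; its concurrency reflects worker_concurrency
    # from the config module loaded via config_from_object.
    concurrency = int(sender.concurrency)
    print(f"sizing pools for {concurrency} task threads + {EXTRA_CONCURRENCY} overflow")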

View File

@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
# Ensure the user files indexing worker registers the doc_id migration task
# TODO(subash): remove this once the doc_id migration is complete
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -324,5 +324,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: Iterator[list[Document]],
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
if isinstance(runnable_connector, SlimConnector):
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
if isinstance(runnable_connector, SlimConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
doc_batch_processing_func = (
rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(lambda x: x)
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
else lambda x: x
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():

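The refactor replaces the re-assignment of doc_batch_processing_func with a single conditional expression, so the rate-limited and identity pass-throughs are chosen up front. A sketch of that pattern with a hypothetical limit; rate_limit_builder and its arguments follow the usage shown in the diff (its import is outside the hunk and omitted here).
MAX_CALLS_PER_MINUTE = 10  # hypothetical limit for illustration

process_batch = (
    rate_limit_builder(max_calls=MAX_CALLS_PER_MINUTE, period=60)(lambda ids: ids)
    if MAX_CALLS_PER_MINUTE
    else (lambda ids: ids)
)

# process_batch({"doc-1", "doc-2"}) returns the same set of ids, but at most
# MAX_CALLS_PER_MINUTE batches pass through per 60-second window when the
# limit is configured.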
View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1 # Single worker is sufficient for monitoring
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -33,17 +33,34 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "check-for-user-file-project-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-delete",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "user-file-docid-migration",
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
"schedule": timedelta(minutes=1),
"schedule": timedelta(minutes=10),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
},
},
{
@@ -85,9 +102,9 @@ beat_task_templates: list[dict] = [
{
"name": "check-for-index-attempt-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
"schedule": timedelta(hours=1),
"schedule": timedelta(minutes=30),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},

View File

@@ -42,6 +42,12 @@ from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.permission_sync_attempt import (
delete_doc_permission_sync_attempts__no_commit,
)
from onyx.db.permission_sync_attempt import (
delete_external_group_permission_sync_attempts__no_commit,
)
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -441,6 +447,16 @@ def monitor_connector_deletion_taskset(
cc_pair_id=cc_pair_id,
)
# permission sync attempts
delete_doc_permission_sync_attempts__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
delete_external_group_permission_sync_attempts__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,

View File

@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
tenant_id: str,
batch_num: int,
) -> None:
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()
if tenant_id:

View File

@@ -895,6 +895,9 @@ def monitor_celery_queues_helper(
n_user_file_project_sync = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery
)
n_user_file_delete = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_DELETE, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -924,6 +927,7 @@ def monitor_celery_queues_helper(
f"user_files_indexing={n_user_files_indexing} "
f"user_file_processing={n_user_file_processing} "
f"user_file_project_sync={n_user_file_project_sync} "
f"user_file_delete={n_user_file_delete} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -9,17 +9,19 @@ import sqlalchemy as sa
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
@@ -36,15 +38,15 @@ from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
from onyx.indexing.embedder import DefaultIndexingEmbedder
@@ -68,6 +70,56 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
task_logger.info(
f"Visiting chunks for index={index_name} with selection={selection}"
)
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "100", # Use smaller batch size to avoid timeouts
}
if continuation:
params["continuation"] = continuation
resp = http_client.get(base_url, params=params, timeout=None)
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def _get_document_chunk_count(
*,
index_name: str,
selection: str,
) -> int:
chunk_count = 0
continuation = None
while True:
docs, continuation = _visit_chunks(
http_client=HttpxPool.get("vespa"),
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
chunk_count += len(docs)
if not continuation:
break
return chunk_count
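A note on the @retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0)) decorator on _visit_chunks above: it comes from the retry package and allows up to three attempts with an exponentially growing, jittered pause between them. A rough manual equivalent (illustrative only; the package accounts for jitter slightly differently) looks like this:

import random
import time
from collections.abc import Callable
from typing import Any


def call_with_retry(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    delay = 1.0  # initial pause, mirrors delay=1
    for attempt in range(3):  # mirrors tries=3
        try:
            return fn(*args, **kwargs)
        except Exception:
            if attempt == 2:
                raise  # third attempt failed; re-raise the last error
            # sleep roughly delay + U(0, 1) seconds, then double the base delay
            time.sleep(delay + random.uniform(0.0, 1.0))
            delay *= 2  # mirrors backoff=2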
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
soft_time_limit=300,
@@ -134,7 +186,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
@@ -244,18 +297,31 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
task_logger.error(
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
# don't update the status if the user file is being deleted
# Re-fetch to avoid mypy error
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
and current_user_file.status != UserFileStatus.DELETING
):
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
except Exception as e:
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
# don't update the status if the user file is being deleted
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
and current_user_file.status != UserFileStatus.DELETING
):
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
elapsed = time.monotonic() - start
@@ -268,7 +334,9 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
uf.status = UserFileStatus.FAILED
# don't update the status if the user file is being deleted
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
@@ -281,6 +349,149 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
file_lock.release()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
soft_time_limit=300,
bind=True,
ignore_result=True,
)
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with DELETING status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_delete - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_DELETE_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
try:
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.status == UserFileStatus.DELETING
)
)
.scalars()
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_DELETE,
priority=OnyxCeleryPriority.HIGH,
)
enqueued += 1
except Exception as e:
task_logger.exception(
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
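The shape of check_for_user_file_delete is a common beat pattern: take a short, non-blocking Redis lock so only one scheduler run scans at a time, fan out one task per matching row to a dedicated queue, and always release the lock. A generic, hypothetical sketch of that pattern (the app, lock key, task name, and queue below are placeholders, not Onyx's):

from celery import Celery
from redis import Redis

celery_app = Celery("example", broker="redis://localhost:6379/0")
redis_client = Redis()


def scan_and_fan_out(ids_to_delete: list[str]) -> int:
    lock = redis_client.lock("example:delete-beat-lock", timeout=300)
    if not lock.acquire(blocking=False):
        return 0  # another beat run holds the lock; skip this cycle
    enqueued = 0
    try:
        for item_id in ids_to_delete:
            celery_app.send_task(
                "example.delete_single_item",
                kwargs={"item_id": item_id},
                queue="item_delete",
            )
            enqueued += 1
    finally:
        if lock.owned():
            lock.release()
    return enqueued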
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str
) -> None:
"""Process a single user file delete."""
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
)
return None
try:
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(document_index)
index_name = active_search_settings.primary.index_name
selection = f"{index_name}.document_id=='{user_file_id}'"
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_delete - User file not found id={user_file_id}"
)
return None
# 1) Delete Vespa chunks for the document
chunk_count = 0
if user_file.chunk_count is None or user_file.chunk_count == 0:
chunk_count = _get_document_chunk_count(
index_name=index_name,
selection=selection,
)
else:
chunk_count = user_file.chunk_count
retry_index.delete_single(
doc_id=user_file_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
try:
file_store.delete_file(user_file.file_id)
file_store.delete_file(
user_file_id_to_plaintext_file_name(user_file.id)
)
except Exception as e:
# This block is executed only if the file is not found in the filestore
task_logger.exception(
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
)
# 3) Finally, delete the UserFile row
db_session.delete(user_file)
db_session.commit()
task_logger.info(
f"process_single_user_file_delete - Completed id={user_file_id}"
)
except Exception as e:
task_logger.exception(
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
file_lock.release()
return None
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
soft_time_limit=300,
@@ -306,8 +517,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.needs_project_sync.is_(True)
and UserFile.status == UserFileStatus.COMPLETED
sa.and_(
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
)
)
)
.scalars()
@@ -348,7 +561,7 @@ def process_single_user_file_project_sync(
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
@@ -359,6 +572,15 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
@@ -416,147 +638,155 @@ def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
return old_id
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
resp = http_client.get(base_url, params=params, timeout=None)
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def update_legacy_plaintext_file_records() -> None:
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
keys. Copies each S3 object to its expected UUID key and updates DB.
Examples:
- Old key: bucket/schema/plaintext_<int>
- New key: bucket/schema/plaintext_<uuid>
"""
def _update_document_id_in_vespa(
*,
index_name: str,
old_doc_id: str,
new_doc_id: str,
user_project_ids: list[int] | None = None,
) -> None:
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
task_logger.info("update_legacy_plaintext_file_records - Starting")
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
task_logger.debug(f"Vespa selection: {selection}")
with get_session_with_current_tenant() as db_session:
store = get_default_file_store()
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
if not isinstance(store, S3BackedFileStore):
task_logger.info(
"update_legacy_plaintext_file_records - Skipping non-S3 store"
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
return
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
)
)
.scalars()
.all()
)
task_logger.info(
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request: dict[str, Any] = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
if user_project_ids is not None:
update_request["fields"][USER_PROJECT] = {
"assign": user_project_ids
}
r = http_client.put(vespa_url, json=update_request)
r.raise_for_status()
if not continuation:
break
if fr.bucket_name is None:
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
continue
if fr.object_key is None:
task_logger.warning(f"id={fr.file_id} - Object key is None")
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
)
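The plaintext-record normalization above relies on the fact that S3 has no rename operation: the object is copied to its expected key and the old key is deleted best-effort afterwards. A self-contained sketch of that copy-then-delete step with boto3 (bucket and keys are hypothetical):

import boto3


def rename_s3_object(bucket: str, old_key: str, new_key: str) -> None:
    s3 = boto3.client("s3")
    s3.copy_object(
        CopySource={"Bucket": bucket, "Key": old_key},
        Bucket=bucket,
        Key=new_key,
        MetadataDirective="COPY",  # keep the original object's metadata
    )
    try:
        s3.delete_object(Bucket=bucket, Key=old_key)
    except Exception:
        pass  # best-effort cleanup; the copy to the new key already succeeded


# e.g. rename_s3_object("onyx-file-store", "schema/plaintext_123", "schema/plaintext_<uuid>")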
@shared_task(
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
ignore_result=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
bind=True,
)
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
- Update `search_doc.document_id` to the same UUID string.
"""
task_logger.info(
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
)
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
task_logger.info(
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
)
return False
updated_count = 0
try:
update_legacy_plaintext_file_records()
# Track lock renewal
last_lock_time = time.monotonic()
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
active_settings.primary,
active_settings.secondary,
search_settings=active_settings.primary,
secondary_search_settings=active_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
index_name = "danswer_index"
# Fetch mappings of legacy -> new ids
rows = db_session.execute(
sa.select(
UserFile.document_id.label("document_id"),
UserFile.id.label("id"),
).where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
).all()
retry_index = RetryDocumentIndex(document_index)
# dedupe by old document_id
seen: set[str] = set()
for row in rows:
old_doc_id = str(row.document_id)
new_uuid = str(row.id)
if not old_doc_id or not new_uuid or old_doc_id in seen:
continue
seen.add(old_doc_id)
# collect user project ids for a combined Vespa update
user_project_ids: list[int] | None = None
try:
uf = db_session.get(UserFile, UUID(new_uuid))
if uf is not None:
user_project_ids = [project.id for project in uf.projects]
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
)
try:
_update_document_id_in_vespa(
index_name=index_name,
old_doc_id=old_doc_id,
new_doc_id=new_uuid,
user_project_ids=user_project_ids,
)
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
)
# Update search_doc records to refer to the UUID string
# we are not filtering on document_id_migrated = False because, if the migration already completed,
# it will not run again and we would never update the search_doc records affected by the issue fixed here
# Select user files with a legacy doc id that have not been migrated
user_files = (
db_session.execute(
sa.select(UserFile).where(UserFile.document_id.is_not(None))
sa.select(UserFile).where(
sa.and_(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
)
)
.scalars()
.all()
)
task_logger.info(
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
)
# Query all SearchDocs that need updating
search_docs = (
db_session.execute(
@@ -567,9 +797,9 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
.scalars()
.all()
)
task_logger.info(f"Found {len(user_files)} user files to update")
task_logger.info(f"Found {len(search_docs)} search docs to update")
task_logger.info(
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
)
# Build a map of normalized doc IDs to SearchDocs
search_doc_map: dict[str, list[SearchDoc]] = {}
@@ -579,102 +809,129 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
search_doc_map[doc_id] = []
search_doc_map[doc_id].append(sd)
# Process each UserFile and update matching SearchDocs
updated_count = 0
for uf in user_files:
doc_id = uf.document_id
if doc_id.startswith("USER_FILE_CONNECTOR__"):
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
task_logger.debug(
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
)
if doc_id in search_doc_map:
# Update the SearchDoc to use the UserFile's UUID
for search_doc in search_doc_map[doc_id]:
search_doc.document_id = str(uf.id)
db_session.add(search_doc)
ids_preview = list(search_doc_map.keys())[:5]
task_logger.debug(
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
)
task_logger.debug(
f"user_file_docid_migration_task - search_doc_map total items: "
f"{sum(len(docs) for docs in search_doc_map.values())}"
)
for user_file in user_files:
# Periodically renew the Redis lock to prevent expiry mid-run
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
):
renewed = False
try:
# extend lock ttl to full timeout window
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
renewed = True
except Exception:
# if extend fails, best-effort reacquire as a fallback
try:
lock.reacquire()
renewed = True
except Exception:
renewed = False
last_lock_time = current_time
if not renewed or not lock.owned():
task_logger.error(
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
)
return False
# Mark UserFile as migrated
uf.document_id_migrated = True
db_session.add(uf)
try:
clean_old_doc_id = replace_invalid_doc_id_characters(
user_file.document_id
)
normalized_doc_id = _normalize_legacy_user_file_doc_id(
clean_old_doc_id
)
user_project_ids = [project.id for project in user_file.projects]
task_logger.info(
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
)
index_name = active_settings.primary.index_name
# First find the chunk count using a direct Vespa query
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
# Count all chunks for this document
chunk_count = _get_document_chunk_count(
index_name=index_name,
selection=selection,
)
task_logger.info(
f"Found {chunk_count} chunks for document {normalized_doc_id}"
)
# Now update Vespa chunks with the found chunk count using retry_index
updated_chunks = retry_index.update_single(
doc_id=str(normalized_doc_id),
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=VespaDocumentFields(document_id=str(user_file.id)),
user_fields=VespaDocumentUserFields(
user_projects=user_project_ids
),
)
user_file.chunk_count = updated_chunks
# Update the SearchDocs
actual_doc_id = str(user_file.document_id)
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
actual_doc_id
)
if (
normalized_doc_id in search_doc_map
or normalized_actual_doc_id in search_doc_map
):
to_update = (
search_doc_map[normalized_doc_id]
if normalized_doc_id in search_doc_map
else search_doc_map[normalized_actual_doc_id]
)
task_logger.debug(
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
)
for search_doc in to_update:
search_doc.document_id = str(user_file.id)
db_session.add(search_doc)
user_file.document_id_migrated = True
db_session.add(user_file)
db_session.commit()
updated_count += 1
except Exception as per_file_exc:
# Rollback the current transaction and continue with the next file
db_session.rollback()
task_logger.exception(
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
f"{per_file_exc.__class__.__name__}"
)
task_logger.info(
f"Updated {updated_count} SearchDoc records with new UUIDs"
f"user_file_docid_migration_task - Updated {updated_count} user files"
)
db_session.commit()
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
try:
store = get_default_file_store()
# Only supported for S3-backed stores where we can manipulate object keys
if isinstance(store, S3BackedFileStore):
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.file_id.like("plaintext_%"),
)
)
.scalars()
.all()
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(
Bucket=fr.bucket_name, Key=fr.object_key
)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed plaintext object normalize for "
f"id={fr.file_id} - {e.__class__.__name__}"
)
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
)
else:
task_logger.info(
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
)
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
)
task_logger.info(
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
)
return True
except Exception:
except Exception as e:
task_logger.exception(
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
f"(updated={updated_count}) exception={e.__class__.__name__}"
)
return False
finally:
if lock.owned():
lock.release()

View File
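One detail worth calling out from user_file_docid_migration_task above is the periodic lock renewal inside the per-file loop. With redis-py, Lock.extend(additional_time) adds to the remaining TTL (pass replace_ttl=True to reset it instead), while Lock.reacquire() resets the TTL to the lock's original timeout; the task uses the latter as a fallback when extending fails. A generic sketch of that renewal loop, with placeholder names:

import time

from redis import Redis

LOCK_TIMEOUT = 600  # seconds


def long_running_job(items: list[str]) -> None:
    lock = Redis().lock("example:migration-lock", timeout=LOCK_TIMEOUT)
    if not lock.acquire(blocking=False):
        return  # another run is already in progress
    last_renewal = time.monotonic()
    try:
        for item in items:
            if time.monotonic() - last_renewal >= LOCK_TIMEOUT / 4:
                try:
                    lock.extend(LOCK_TIMEOUT, replace_ttl=True)  # reset the TTL
                except Exception:
                    lock.reacquire()  # fallback: TTL back to the original timeout
                last_renewal = time.monotonic()
            print(f"processing {item}")  # placeholder for the per-item work
    finally:
        if lock.owned():
            lock.release()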

@@ -0,0 +1,10 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)
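The new module above bootstraps a Celery app through fetch_versioned_implementation, Onyx's indirection for swapping in an enterprise edition of a symbol when one exists. As a rough, hypothetical illustration of that kind of lookup (not the actual helper's behavior), the pattern amounts to trying an "ee."-prefixed module first and falling back to the open-source one:

import importlib
from typing import Any


def fetch_impl(module_path: str, attribute: str) -> Any:
    # Try the enterprise variant first, then fall back to the base module.
    for candidate in (f"ee.{module_path}", module_path):
        try:
            return getattr(importlib.import_module(candidate), attribute)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"{attribute} not found under {module_path}")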

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
def get_old_index_attempts(
@@ -21,6 +22,10 @@ def get_old_index_attempts(
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
"""Clean up multiple index attempts"""
db_session.query(IndexAttemptError).filter(
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
).delete(synchronize_session=False)
db_session.query(IndexAttempt).filter(
IndexAttempt.id.in_(index_attempt_ids)
).delete(synchronize_session=False)
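Design note on the hunk above: IndexAttemptError rows reference index attempts via index_attempt_id, so the child rows are bulk-deleted first, presumably to avoid orphaned references and any foreign-key violations; synchronize_session=False skips reconciling in-memory session state, which is the usual choice for large bulk deletes.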

View File

@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
@@ -101,7 +102,6 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
is of type LOAD_STATE, the list will be considered complete; otherwise, incomplete.
"""
from onyx.connectors.factory import instantiate_connector
task = attempt.connector_credential_pair.connector.input_type

View File

@@ -0,0 +1,29 @@
from collections.abc import Callable
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Memory
from onyx.db.models import User
def make_memories_callback(
user: User | None, db_session: Session
) -> Callable[[], list[str]]:
def memories_callback() -> list[str]:
if user is None:
return []
user_info = [
f"User's name: {user.personal_name}" if user.personal_name else "",
f"User's role: {user.personal_role}" if user.personal_role else "",
f"User's email: {user.email}" if user.email else "",
]
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user.id)
).all()
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
return user_info + memories
return memories_callback
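The new make_memories_callback module above returns a zero-argument closure so the memory lookup runs lazily, only when the prompt is actually built. A self-contained sketch of the same closure pattern using hypothetical stand-in models (not the Onyx User/Memory models) and an in-memory SQLite database:

from collections.abc import Callable

from sqlalchemy import ForeignKey, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class StandInUser(Base):
    __tablename__ = "app_user"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String)


class StandInMemory(Base):
    __tablename__ = "memory"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("app_user.id"))
    memory_text: Mapped[str] = mapped_column(String)


def make_callback(
    user: StandInUser | None, db_session: Session
) -> Callable[[], list[str]]:
    # The closure captures the user and session; nothing is queried until the
    # prompt builder actually invokes the callback.
    def memories_callback() -> list[str]:
        if user is None:
            return []
        rows = db_session.scalars(
            select(StandInMemory).where(StandInMemory.user_id == user.id)
        ).all()
        return [f"User's email: {user.email}"] + [r.memory_text for r in rows]

    return memories_callback


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        user = StandInUser(id=1, email="a@example.com")
        session.add_all(
            [user, StandInMemory(id=1, user_id=1, memory_text="Prefers short answers")]
        )
        session.commit()
        print(make_callback(user, session)())  # queried lazily, at call time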

View File

@@ -2,18 +2,20 @@ import re
import time
import traceback
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast
from typing import Protocol
from uuid import UUID
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.chat_utils import process_kg_commands
from onyx.chat.memories import make_memories_callback
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatBasicResponse
@@ -26,12 +28,16 @@ from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.packet_proccessing.process_streamed_packets import (
process_streamed_packets,
)
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import (
default_build_system_message_v2,
)
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message_v2
from onyx.chat.turn import fast_chat_turn
from onyx.chat.turn.infra.emitter import get_default_emitter
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -70,6 +76,8 @@ from onyx.db.projects import get_project_instructions
from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.feature_flags.factory import get_default_feature_flag_provider
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import build_frontend_file_url
@@ -82,6 +90,7 @@ from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -89,6 +98,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
@@ -108,11 +118,14 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
class PartialResponse(Protocol):
def __call__(
self,
@@ -137,24 +150,35 @@ def _build_project_llm_docs(
return project_llm_docs
project_file_id_set = set(project_file_ids)
for f in in_memory_user_files:
# Only include files that belong to the project (not ad-hoc uploads)
if project_file_id_set and (f.file_id in project_file_id_set):
try:
text_content = f.content.decode("utf-8", errors="ignore")
except Exception:
text_content = ""
# Build a short blurb from the file content for better UI display
blurb = (
(text_content[:200] + "...")
if len(text_content) > 200
else text_content
)
def _strip_nuls(s: str) -> str:
return s.replace("\x00", "") if s else s
for f in in_memory_user_files:
if project_file_id_set and (f.file_id in project_file_id_set):
cleaned_filename = _strip_nuls(f.filename or str(f.file_id))
if f.file_type.is_text_file():
try:
text_content = f.content.decode("utf-8", errors="ignore")
text_content = _strip_nuls(text_content)
except Exception:
text_content = ""
# Build a short blurb from the file content for better UI display
blurb = (
(text_content[:200] + "...")
if len(text_content) > 200
else text_content
)
else:
# Non-text (e.g., images): do not decode bytes; keep empty content but allow citation
text_content = ""
blurb = f"[{f.file_type.value}] {cleaned_filename}"
# Provide basic metadata to improve SavedSearchDoc display
file_metadata: dict[str, str | list[str]] = {
"filename": f.filename or str(f.file_id),
"filename": cleaned_filename,
"file_type": f.file_type.value,
}
@@ -163,7 +187,7 @@ def _build_project_llm_docs(
document_id=str(f.file_id),
content=text_content,
blurb=blurb,
semantic_identifier=f.filename or str(f.file_id),
semantic_identifier=cleaned_filename,
source_type=DocumentSource.USER_FILE,
metadata=file_metadata,
updated_at=None,
@@ -352,14 +376,12 @@ def stream_chat_message_objects(
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
persona = _get_persona_for_chat_session(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
default_persona=chat_session.persona,
)
# TODO: remove once we have an endpoint for this stuff
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
@@ -729,15 +751,37 @@ def stream_chat_message_objects(
and (file.file_id not in project_file_ids)
]
)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
feature_flag_provider = get_default_feature_flag_provider()
simple_agent_framework_enabled = (
feature_flag_provider.feature_enabled_for_user_tenant(
flag_key=SIMPLE_AGENT_FRAMEWORK,
user=user,
tenant_id=tenant_id,
)
and not new_msg_req.use_agentic_search
)
prompt_user_message = (
default_build_user_message_v2(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config, llm.config),
)
if simple_agent_framework_enabled
else default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
)
)
mem_callback = make_memories_callback(user, db_session)
system_message = (
default_build_system_message_v2(prompt_config, llm.config, mem_callback)
if simple_agent_framework_enabled
else default_build_system_message(prompt_config, llm.config, mem_callback)
)
prompt_builder = AnswerPromptBuilder(
user_message=prompt_user_message,
system_message=system_message,
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
@@ -779,11 +823,21 @@ def stream_chat_message_objects(
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
project_instructions=project_instructions,
)
if simple_agent_framework_enabled:
yield from _fast_message_stream(
answer,
tools,
db_session,
get_redis_client(),
chat_session_id,
reserved_message_id,
)
else:
from onyx.chat.packet_proccessing import process_streamed_packets
# Process streamed packets using the new packet processing module
yield from process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
yield from process_streamed_packets.process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -820,6 +874,59 @@ def stream_chat_message_objects(
return
# TODO: Refactor this to live somewhere else
def _fast_message_stream(
answer: Answer,
tools: list[Tool],
db_session: Session,
redis_client: Redis,
chat_session_id: UUID,
reserved_message_id: int,
) -> Generator[Packet, None, None]:
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.llm.litellm_singleton import LitellmModel
image_generation_tool_instance = None
okta_profile_tool_instance = None
for tool in tools:
if isinstance(tool, ImageGenerationTool):
image_generation_tool_instance = tool
elif isinstance(tool, OktaProfileTool):
okta_profile_tool_instance = tool
converted_message_history = [
PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
for message in answer.graph_inputs.prompt_builder.build()
]
emitter = get_default_emitter()
return fast_chat_turn.fast_chat_turn(
messages=converted_message_history,
# TODO: Maybe we can use some DI framework here?
dependencies=ChatTurnDependencies(
llm_model=LitellmModel(
model=answer.graph_tooling.primary_llm.config.model_name,
base_url=answer.graph_tooling.primary_llm.config.api_base,
api_key=answer.graph_tooling.primary_llm.config.api_key,
),
llm=answer.graph_tooling.primary_llm,
tools=tools_to_function_tools(tools),
search_pipeline=answer.graph_tooling.search_tool,
image_generation_tool=image_generation_tool_instance,
okta_profile_tool=okta_profile_tool_instance,
db_session=db_session,
redis_client=redis_client,
emitter=emitter,
),
chat_session_id=chat_session_id,
message_id=reserved_message_id,
research_type=answer.graph_config.behavior.research_type,
)
@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,

View File
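One small change above in _build_project_llm_docs is the _strip_nuls guard on filenames and decoded text. This looks like defensive sanitization: embedded NUL ("\x00") bytes are rejected by PostgreSQL text columns and tend to break JSON payloads, so they are stripped before the content is used for blurbs or document metadata. A tiny usage check of the same guard:

def _strip_nuls(s: str) -> str:
    return s.replace("\x00", "") if s else s


assert _strip_nuls("report\x00.pdf") == "report.pdf"
assert _strip_nuls("") == ""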

@@ -21,8 +21,11 @@ from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT_V2
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.prompts.prompt_utils import handle_memories
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
@@ -31,9 +34,41 @@ from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
def default_build_system_message_v2(
prompt_config: PromptConfig,
llm_config: LLMConfig,
memories_callback: Callable[[], list[str]] | None = None,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
system_prompt += REQUIRE_CITATION_STATEMENT_V2
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
if not tag_handled_prompt:
return None
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
if memories_callback:
tag_handled_prompt = handle_memories(tag_handled_prompt, memories_callback)
return SystemMessage(content=tag_handled_prompt)
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
memories_callback: Callable[[], list[str]] | None = None,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
@@ -52,9 +87,32 @@ def default_build_system_message(
if not tag_handled_prompt:
return None
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
if memories_callback:
tag_handled_prompt = handle_memories(tag_handled_prompt, memories_callback)
return SystemMessage(content=tag_handled_prompt)
def default_build_user_message_v2(
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
) -> HumanMessage:
user_prompt = user_query
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=(
build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
)
)
return user_msg
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,

Some files were not shown because too many files have changed in this diff.