mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 03:35:48 +00:00
Compare commits
157 Commits
projects-r
...
fix/projec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c54752e3ba | ||
|
|
bd32795804 | ||
|
|
7aa6b01ac0 | ||
|
|
93084a3a39 | ||
|
|
e46f632570 | ||
|
|
bbb4b9eda3 | ||
|
|
12b7c7d4dd | ||
|
|
464967340b | ||
|
|
a2308c2f45 | ||
|
|
2ee9f79f71 | ||
|
|
c3904b7c96 | ||
|
|
5009dcf911 | ||
|
|
c7b4a0fad9 | ||
|
|
60a402fcab | ||
|
|
c9bb078a37 | ||
|
|
c36c2a6c8d | ||
|
|
f9e2f9cbb4 | ||
|
|
0b7c808480 | ||
|
|
0a6ff30ee4 | ||
|
|
dc036eb452 | ||
|
|
ee950b9cbd | ||
|
|
dd71765849 | ||
|
|
dc6b97f1b1 | ||
|
|
d960c23b6a | ||
|
|
d9c753ba92 | ||
|
|
60234dd6da | ||
|
|
f88ef2e9ff | ||
|
|
6b479a01ea | ||
|
|
248fe416e1 | ||
|
|
cbea4bb75c | ||
|
|
4a147a48dc | ||
|
|
a77025cd46 | ||
|
|
d10914ccc6 | ||
|
|
7d44d48f87 | ||
|
|
82fd0e0316 | ||
|
|
d7e4c47ef1 | ||
|
|
799b0df1cb | ||
|
|
b31d36564a | ||
|
|
84df0a1bf9 | ||
|
|
dbc53fe176 | ||
|
|
1e4ba93daa | ||
|
|
d872715620 | ||
|
|
46ad541ebc | ||
|
|
613907a06f | ||
|
|
ff723992d1 | ||
|
|
bda3c6b189 | ||
|
|
264d1de994 | ||
|
|
335571ce79 | ||
|
|
4d3fac2574 | ||
|
|
7c229dd103 | ||
|
|
b5df182a36 | ||
|
|
7e7cfa4187 | ||
|
|
69d8430288 | ||
|
|
467d294b30 | ||
|
|
ba2dd18233 | ||
|
|
891eeb0212 | ||
|
|
9085731ff0 | ||
|
|
f5d88c47f4 | ||
|
|
807e5c21b0 | ||
|
|
1bcd795011 | ||
|
|
aae357df40 | ||
|
|
4f03e85c57 | ||
|
|
c3411fb28d | ||
|
|
b3d1b1f4aa | ||
|
|
cbb86c12aa | ||
|
|
8fd606b713 | ||
|
|
d69170ee13 | ||
|
|
e356c5308c | ||
|
|
3026ac8912 | ||
|
|
0cee7c849f | ||
|
|
14bfb7fd0c | ||
|
|
804e48a3da | ||
|
|
907271656e | ||
|
|
1f11dd3e46 | ||
|
|
048561ce0b | ||
|
|
8718f10c38 | ||
|
|
ab4d820089 | ||
|
|
77ae4f1a45 | ||
|
|
8fd1f42a1c | ||
|
|
b94c7e581b | ||
|
|
c90ff701dc | ||
|
|
b1ad58c5af | ||
|
|
345f9b3497 | ||
|
|
4671d18d4f | ||
|
|
f0598be875 | ||
|
|
eb361c6434 | ||
|
|
e39b0a921c | ||
|
|
2dd8a8c788 | ||
|
|
8b79e2e90b | ||
|
|
d05941d1bd | ||
|
|
50070fb264 | ||
|
|
5792d8d5ed | ||
|
|
e1c4b33cf7 | ||
|
|
2c2f6e7c23 | ||
|
|
3d30233d46 | ||
|
|
875f8cff5c | ||
|
|
6e4686a09f | ||
|
|
237c18e15e | ||
|
|
a71d80329d | ||
|
|
91c392b4fc | ||
|
|
a25df4002d | ||
|
|
436a5add88 | ||
|
|
3a4bb239b1 | ||
|
|
2acb4cfdb6 | ||
|
|
f1d626adb0 | ||
|
|
5ca604f186 | ||
|
|
c19c76c3ad | ||
|
|
4555f6badc | ||
|
|
71bd643537 | ||
|
|
23f70f0a96 | ||
|
|
c97672559a | ||
|
|
243f0bbdbd | ||
|
|
0a5ca7f1cf | ||
|
|
8d56d213ec | ||
|
|
cea2ea924b | ||
|
|
569d205e31 | ||
|
|
9feff5002f | ||
|
|
a1314e49a3 | ||
|
|
463f839154 | ||
|
|
5a0fe3c1d1 | ||
|
|
8ac5c86c1e | ||
|
|
d803b48edd | ||
|
|
bc3adcdc89 | ||
|
|
95e27f1c30 | ||
|
|
d0724312db | ||
|
|
5b1021f20b | ||
|
|
55cdbe396f | ||
|
|
e8fe0fecd2 | ||
|
|
5b4fc91a3e | ||
|
|
afd2d8c362 | ||
|
|
8a8cf13089 | ||
|
|
c7e872d4e3 | ||
|
|
1dbe926518 | ||
|
|
d095bec6df | ||
|
|
58e8d501a1 | ||
|
|
a39782468b | ||
|
|
d747b48d22 | ||
|
|
817de23854 | ||
|
|
6474d30ba0 | ||
|
|
6c9635373a | ||
|
|
1a945b6f94 | ||
|
|
526c76fa08 | ||
|
|
932e62531f | ||
|
|
83768e2ff1 | ||
|
|
f23b6506f4 | ||
|
|
5f09318302 | ||
|
|
674e789036 | ||
|
|
cb514e6e34 | ||
|
|
965dad785c | ||
|
|
c9558224d2 | ||
|
|
c2dbd3fd1e | ||
|
|
d27c2b1b4e | ||
|
|
8c52444bda | ||
|
|
b4caa85cd4 | ||
|
|
57163dd936 | ||
|
|
15f2a0bf60 | ||
|
|
aeae7ebdef |
50
.github/actions/prepare-build/action.yml
vendored
Normal file
50
.github/actions/prepare-build/action.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: "Prepare Build (OpenAPI generation)"
|
||||
description: "Sets up Python with uv, installs deps, generates OpenAPI schema and Python client, uploads artifact"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install Python dependencies with uv
|
||||
shell: bash
|
||||
run: |
|
||||
uv pip install --system \
|
||||
-r backend/requirements/default.txt \
|
||||
-r backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
shell: bash
|
||||
working-directory: backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
shell: bash
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
@@ -109,6 +109,15 @@ jobs:
|
||||
# Needed for trivyignore
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
|
||||
@@ -120,6 +120,15 @@ jobs:
|
||||
if: needs.precheck.outputs.should-run == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
|
||||
4
.github/workflows/docker-tag-latest.yml
vendored
4
.github/workflows/docker-tag-latest.yml
vendored
@@ -35,3 +35,7 @@ jobs:
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
|
||||
|
||||
- name: Pull, Tag and Push Model Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
|
||||
|
||||
171
.github/workflows/hotfix-release-branches.yml
vendored
171
.github/workflows/hotfix-release-branches.yml
vendored
@@ -1,171 +0,0 @@
|
||||
# This workflow is intended to be manually triggered via the GitHub Action tab.
|
||||
# Given a hotfix branch, it will attempt to open a PR to all release branches and
|
||||
# by default auto merge them
|
||||
|
||||
name: Hotfix release branches
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
hotfix_commit:
|
||||
description: "Hotfix commit hash"
|
||||
required: true
|
||||
hotfix_suffix:
|
||||
description: "Hotfix branch suffix (e.g. hotfix/v0.8-{suffix})"
|
||||
required: true
|
||||
release_branch_pattern:
|
||||
description: "Release branch pattern (regex)"
|
||||
required: true
|
||||
default: "release/.*"
|
||||
auto_merge:
|
||||
description: "Automatically merge the hotfix PRs"
|
||||
required: true
|
||||
type: choice
|
||||
default: "true"
|
||||
options:
|
||||
- true
|
||||
- false
|
||||
|
||||
jobs:
|
||||
hotfix_release_branches:
|
||||
permissions: write-all
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# use a lower powered instance since this just does i/o to docker hub
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
# needs RKUO_DEPLOY_KEY for write access to merge PR's
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@onyx.app"
|
||||
|
||||
- name: Fetch All Branches
|
||||
run: |
|
||||
git fetch --all --prune
|
||||
|
||||
- name: Verify Hotfix Commit Exists
|
||||
run: |
|
||||
git rev-parse --verify "${{ github.event.inputs.hotfix_commit }}" || { echo "Commit not found: ${{ github.event.inputs.hotfix_commit }}"; exit 1; }
|
||||
|
||||
- name: Get Release Branches
|
||||
id: get_release_branches
|
||||
run: |
|
||||
BRANCHES=$(git branch -r | grep -E "${{ github.event.inputs.release_branch_pattern }}" | sed 's|origin/||' | tr -d ' ')
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found matching pattern '${{ github.event.inputs.release_branch_pattern }}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found release branches:"
|
||||
echo "$BRANCHES"
|
||||
|
||||
# Join the branches into a single line separated by commas
|
||||
BRANCHES_JOINED=$(echo "$BRANCHES" | tr '\n' ',' | sed 's/,$//')
|
||||
|
||||
# Set the branches as an output
|
||||
echo "branches=$BRANCHES_JOINED" >> $GITHUB_OUTPUT
|
||||
|
||||
# notes on all the vagaries of wiring up automated PR's
|
||||
# https://github.com/peter-evans/create-pull-request/blob/main/docs/concepts-guidelines.md#triggering-further-workflow-runs
|
||||
# we must use a custom token for GH_TOKEN to trigger the subsequent PR checks
|
||||
- name: Create and Merge Pull Requests to Matching Release Branches
|
||||
env:
|
||||
HOTFIX_COMMIT: ${{ github.event.inputs.hotfix_commit }}
|
||||
HOTFIX_SUFFIX: ${{ github.event.inputs.hotfix_suffix }}
|
||||
AUTO_MERGE: ${{ github.event.inputs.auto_merge }}
|
||||
GH_TOKEN: ${{ secrets.RKUO_PERSONAL_ACCESS_TOKEN }}
|
||||
run: |
|
||||
# Get the branches from the previous step
|
||||
BRANCHES="${{ steps.get_release_branches.outputs.branches }}"
|
||||
|
||||
# Convert BRANCHES to an array
|
||||
IFS=$',' read -ra BRANCH_ARRAY <<< "$BRANCHES"
|
||||
|
||||
# Loop through each release branch and create and merge a PR
|
||||
for RELEASE_BRANCH in "${BRANCH_ARRAY[@]}"; do
|
||||
echo "Processing $RELEASE_BRANCH..."
|
||||
|
||||
# Parse out the release version by removing "release/" from the branch name
|
||||
RELEASE_VERSION=${RELEASE_BRANCH#release/}
|
||||
echo "Release version parsed: $RELEASE_VERSION"
|
||||
|
||||
HOTFIX_BRANCH="hotfix/${RELEASE_VERSION}-${HOTFIX_SUFFIX}"
|
||||
echo "Creating PR from $HOTFIX_BRANCH to $RELEASE_BRANCH"
|
||||
|
||||
# Checkout the release branch
|
||||
echo "Checking out $RELEASE_BRANCH"
|
||||
git checkout "$RELEASE_BRANCH"
|
||||
|
||||
# Create the new hotfix branch
|
||||
if git rev-parse --verify "$HOTFIX_BRANCH" >/dev/null 2>&1; then
|
||||
echo "Hotfix branch $HOTFIX_BRANCH already exists. Skipping branch creation."
|
||||
else
|
||||
echo "Branching $RELEASE_BRANCH to $HOTFIX_BRANCH"
|
||||
git checkout -b "$HOTFIX_BRANCH"
|
||||
fi
|
||||
|
||||
# Check if the hotfix commit is a merge commit
|
||||
if git rev-list --merges -n 1 "$HOTFIX_COMMIT" >/dev/null 2>&1; then
|
||||
# -m 1 uses the target branch as the base (which is what we want)
|
||||
echo "Hotfix commit $HOTFIX_COMMIT is a merge commit, using -m 1 for cherry-pick"
|
||||
CHERRY_PICK_CMD="git cherry-pick -m 1 $HOTFIX_COMMIT"
|
||||
else
|
||||
CHERRY_PICK_CMD="git cherry-pick $HOTFIX_COMMIT"
|
||||
fi
|
||||
|
||||
# Perform the cherry-pick
|
||||
echo "Executing: $CHERRY_PICK_CMD"
|
||||
eval "$CHERRY_PICK_CMD"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Cherry-pick failed for $HOTFIX_COMMIT on $HOTFIX_BRANCH. Aborting..."
|
||||
git cherry-pick --abort
|
||||
continue
|
||||
fi
|
||||
|
||||
# Push the hotfix branch to the remote
|
||||
echo "Pushing $HOTFIX_BRANCH..."
|
||||
git push origin "$HOTFIX_BRANCH"
|
||||
echo "Hotfix branch $HOTFIX_BRANCH created and pushed."
|
||||
|
||||
# Check if PR already exists
|
||||
EXISTING_PR=$(gh pr list --head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH" --state open --json number --jq '.[0].number')
|
||||
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "An open PR already exists: #$EXISTING_PR. Skipping..."
|
||||
continue
|
||||
fi
|
||||
|
||||
# Create a new PR and capture the output
|
||||
PR_OUTPUT=$(gh pr create --title "Merge $HOTFIX_BRANCH into $RELEASE_BRANCH" \
|
||||
--body "Automated PR to merge \`$HOTFIX_BRANCH\` into \`$RELEASE_BRANCH\`." \
|
||||
--head "$HOTFIX_BRANCH" --base "$RELEASE_BRANCH")
|
||||
|
||||
# Extract the URL from the output
|
||||
PR_URL=$(echo "$PR_OUTPUT" | grep -Eo 'https://github.com/[^ ]+')
|
||||
echo "Pull request created: $PR_URL"
|
||||
|
||||
# Extract PR number from URL
|
||||
PR_NUMBER=$(basename "$PR_URL")
|
||||
echo "Pull request created: $PR_NUMBER"
|
||||
|
||||
if [ "$AUTO_MERGE" == "true" ]; then
|
||||
echo "Attempting to merge pull request #$PR_NUMBER"
|
||||
|
||||
# Attempt to merge the PR
|
||||
gh pr merge "$PR_NUMBER" --merge --auto --delete-branch
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Pull request #$PR_NUMBER merged successfully."
|
||||
else
|
||||
# Optionally, handle the error or continue
|
||||
echo "Failed to merge pull request #$PR_NUMBER."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
@@ -14,11 +14,10 @@ env:
|
||||
S3_ENDPOINT_URL: "http://localhost:9004"
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
|
||||
54
.github/workflows/pr-integration-tests.yml
vendored
54
.github/workflows/pr-integration-tests.yml
vendored
@@ -19,8 +19,8 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
@@ -67,46 +67,8 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
- name: Prepare build
|
||||
uses: ./.github/actions/prepare-build
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -133,7 +95,8 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -161,7 +124,8 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -195,7 +159,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
|
||||
35
.github/workflows/pr-jest-tests.yml
vendored
Normal file
35
.github/workflows/pr-jest-tests.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Run Jest Tests
|
||||
concurrency:
|
||||
group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
jest-tests:
|
||||
name: Jest Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Jest tests
|
||||
working-directory: ./web
|
||||
run: npm test -- --ci --coverage --maxWorkers=50%
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
retention-days: 7
|
||||
55
.github/workflows/pr-mit-integration-tests.yml
vendored
55
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -16,8 +16,8 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
@@ -64,46 +64,8 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
- name: Prepare build
|
||||
uses: ./.github/actions/prepare-build
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -130,7 +92,8 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -158,7 +121,8 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -192,7 +156,8 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
|
||||
6
.github/workflows/pr-playwright-tests.yml
vendored
6
.github/workflows/pr-playwright-tests.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
@@ -90,7 +90,7 @@ jobs:
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
@@ -123,7 +123,7 @@ jobs:
|
||||
sbom: false
|
||||
push: true
|
||||
outputs: type=registry
|
||||
# no-cache: true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
|
||||
102
.github/workflows/pr-python-connector-tests.yml
vendored
102
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -13,12 +13,20 @@ env:
|
||||
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Cloudflare R2
|
||||
R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS: ${{ vars.R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS }}
|
||||
R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Google Cloud Storage
|
||||
GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
@@ -56,22 +64,22 @@ env:
|
||||
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
|
||||
|
||||
# IMAP
|
||||
IMAP_HOST: ${{ secrets.IMAP_HOST }}
|
||||
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
|
||||
IMAP_HOST: ${{ vars.IMAP_HOST }}
|
||||
IMAP_USERNAME: ${{ vars.IMAP_USERNAME }}
|
||||
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
|
||||
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
|
||||
IMAP_MAILBOXES: ${{ vars.IMAP_MAILBOXES }}
|
||||
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_TEST_BASE_ID: ${{ vars.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ vars.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ vars.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
|
||||
# Sharepoint
|
||||
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_ID: ${{ vars.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
@@ -102,9 +110,12 @@ env:
|
||||
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
|
||||
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
|
||||
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
|
||||
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
|
||||
BITBUCKET_EMAIL: ${{ vars.BITBUCKET_EMAIL }}
|
||||
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
|
||||
|
||||
# Fireflies
|
||||
FIREFLIES_API_KEY: ${{ secrets.FIREFLIES_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
@@ -134,7 +145,24 @@ jobs:
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
- name: Run Tests
|
||||
- name: Detect Connector changes
|
||||
id: changes
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
hubspot:
|
||||
- 'backend/onyx/connectors/hubspot/**'
|
||||
- 'backend/tests/daily/connectors/hubspot/**'
|
||||
salesforce:
|
||||
- 'backend/onyx/connectors/salesforce/**'
|
||||
- 'backend/tests/daily/connectors/salesforce/**'
|
||||
github:
|
||||
- 'backend/onyx/connectors/github/**'
|
||||
- 'backend/tests/daily/connectors/github/**'
|
||||
file_processing:
|
||||
- 'backend/onyx/file_processing/**'
|
||||
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
@@ -144,7 +172,49 @@ jobs:
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors
|
||||
backend/tests/daily/connectors \
|
||||
--ignore backend/tests/daily/connectors/hubspot \
|
||||
--ignore backend/tests/daily/connectors/salesforce \
|
||||
--ignore backend/tests/daily/connectors/github
|
||||
|
||||
- name: Run HubSpot Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/hubspot
|
||||
|
||||
- name: Run Salesforce Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.salesforce == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/salesforce
|
||||
|
||||
- name: Run GitHub Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.github == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/github
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
|
||||
4
.github/workflows/pr-python-model-tests.yml
vendored
4
.github/workflows/pr-python-model-tests.yml
vendored
@@ -15,7 +15,7 @@ env:
|
||||
# Bedrock
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
|
||||
AWS_REGION_NAME: ${{ vars.AWS_REGION_NAME }}
|
||||
|
||||
# API keys for testing
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
@@ -23,7 +23,7 @@ env:
|
||||
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
|
||||
AZURE_API_URL: ${{ vars.AZURE_API_URL }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
|
||||
3
.github/workflows/tag-nightly.yml
vendored
3
.github/workflows/tag-nightly.yml
vendored
@@ -15,6 +15,9 @@ jobs:
|
||||
# actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
|
||||
# Additional NOTE: even though this is named "rkuo", the actual key is tied to the onyx repo
|
||||
# and not rkuo's personal account. It is fine to leave this key as is!
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
@@ -34,8 +34,16 @@ repos:
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
197
.vscode/launch.template.jsonc
vendored
197
.vscode/launch.template.jsonc
vendored
@@ -23,12 +23,10 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery background",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -42,16 +40,29 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery background",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery kg_processing",
|
||||
"Celery monitoring",
|
||||
"Celery user_file_processing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -199,6 +210,35 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=20",
|
||||
"--prefetch-multiplier=4",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
@@ -221,13 +261,100 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.kg_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=kg_processing@%n",
|
||||
"-Q",
|
||||
"kg_processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery kg_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user_file_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,user_file_delete"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
@@ -311,58 +438,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
|
||||
32
AGENTS.md
32
AGENTS.md
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
8. **User File Processing Worker** (`user_file_processing`)
|
||||
- Processes user-uploaded files
|
||||
- Handles user file indexing and project synchronization
|
||||
- Configurable concurrency
|
||||
|
||||
9. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (single worker process)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
|
||||
39
CLAUDE.md
39
CLAUDE.md
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
8. **User File Processing Worker** (`user_file_processing`)
|
||||
- Processes user-uploaded files
|
||||
- Handles user file indexing and project synchronization
|
||||
- Configurable concurrency
|
||||
|
||||
9. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
@@ -82,11 +87,39 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
|
||||
@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
|
||||
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
|
||||
That way we can help future contributors and users avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
[Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
@@ -122,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `onyx/web` run:
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
to manage your Node installations. Once nvm is installed, run:
|
||||
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
@@ -134,8 +137,6 @@ npm i
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to set up pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
@@ -155,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
We use `prettier` for formatting. The desired version will be installed via `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
Pre-commit will also run prettier automatically on files you've recently touched. If any files are re-formatted, your commit will fail.
|
||||
Re-stage your changes and commit again.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
We highly recommend using VSCode debugger for development.
|
||||
**We highly recommend using VSCode debugger for development.**
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
Note: The "Clear and Restart External Volumes and Containers" task will reset your Postgres and Vespa data (relational-db and index).
|
||||
Only run this if you are okay with wiping your data.
|
||||
|
||||
## Features
|
||||
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
|
||||
@@ -15,8 +15,8 @@ ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DO_NOT_TRACK="true" \
|
||||
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -48,22 +48,19 @@ RUN apt-get update && \
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && \
|
||||
playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get update && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord && \
|
||||
# Cleanup for CVEs and size reduction
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
perl-base \
|
||||
xserver-common \
|
||||
@@ -73,15 +70,16 @@ RUN apt-get update && \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
apt-get install -y libxmlsec1-openssl && \
|
||||
# Install here to avoid some packages being cleaned up above
|
||||
apt-get install -y \
|
||||
libxmlsec1-openssl \
|
||||
# Install postgresql-client for easy manual tests
|
||||
postgresql-client && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -rf ~/.cache/uv /tmp/*.txt && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Install postgresql-client for easy manual tests
|
||||
# Install it here to avoid it being cleaned up above
|
||||
RUN apt-get update && apt-get install -y postgresql-client
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
@@ -95,36 +93,37 @@ nltk.download('punkt_tab', quiet=True);"
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
# Enterprise Version Files
|
||||
COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./onyx /app/onyx
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY ./static /app/static
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY ./scripts/debugging /app/scripts/debugging
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# Enterprise Version Files
|
||||
COPY --chown=onyx:onyx ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY --chown=onyx:onyx ./onyx /app/onyx
|
||||
COPY --chown=onyx:onyx ./shared_configs /app/shared_configs
|
||||
COPY --chown=onyx:onyx ./alembic /app/alembic
|
||||
COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
|
||||
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY --chown=onyx:onyx ./static /app/static
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
|
||||
|
||||
# Put logo in assets
|
||||
COPY --chown=onyx:onyx ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -12,7 +12,7 @@ ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN mkdir -p /app && \
|
||||
@@ -34,19 +34,17 @@ RUN set -eux; \
|
||||
pkg-config \
|
||||
curl \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
# Install latest stable Rust (supports Cargo.lock v4)
|
||||
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
|
||||
&& rustc --version && cargo --version
|
||||
&& rustc --version && cargo --version \
|
||||
&& apt-get remove -y --allow-remove-essential perl-base \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt && \
|
||||
rm -rf ~/.cache/uv /tmp/*.txt
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
# Download tokenizers, distilbert for the Onyx model
|
||||
@@ -61,12 +59,11 @@ snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);" && \
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
|
||||
chown -R onyx:onyx /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""add permission sync attempt tables
|
||||
|
||||
Revision ID: 03d710ccf29c
|
||||
Revises: 96a5702df6aa
|
||||
Create Date: 2025-09-11 13:30:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d710ccf29c" # Generate a new unique ID
|
||||
down_revision = "96a5702df6aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the permission sync status enum
|
||||
permission_sync_status_enum = sa.Enum(
|
||||
"not_started",
|
||||
"in_progress",
|
||||
"success",
|
||||
"canceled",
|
||||
"failed",
|
||||
"completed_with_errors",
|
||||
name="permissionsyncstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Create doc_permission_sync_attempt table
|
||||
op.create_table(
|
||||
"doc_permission_sync_attempt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("status", permission_sync_status_enum, nullable=False),
|
||||
sa.Column("total_docs_synced", sa.Integer(), nullable=True),
|
||||
sa.Column("docs_with_permission_errors", sa.Integer(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for doc_permission_sync_attempt
|
||||
op.create_index(
|
||||
"ix_doc_permission_sync_attempt_time_created",
|
||||
"doc_permission_sync_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_permission_sync_attempt_latest_for_cc_pair",
|
||||
"doc_permission_sync_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_permission_sync_attempt_status_time",
|
||||
"doc_permission_sync_attempt",
|
||||
["status", sa.text("time_finished DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create external_group_permission_sync_attempt table
|
||||
# connector_credential_pair_id is nullable - group syncs can be global (e.g., Confluence)
|
||||
op.create_table(
|
||||
"external_group_permission_sync_attempt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column("status", permission_sync_status_enum, nullable=False),
|
||||
sa.Column("total_users_processed", sa.Integer(), nullable=True),
|
||||
sa.Column("total_groups_processed", sa.Integer(), nullable=True),
|
||||
sa.Column("total_group_memberships_synced", sa.Integer(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for external_group_permission_sync_attempt
|
||||
op.create_index(
|
||||
"ix_external_group_permission_sync_attempt_time_created",
|
||||
"external_group_permission_sync_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_group_sync_attempt_cc_pair_time",
|
||||
"external_group_permission_sync_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_group_sync_attempt_status_time",
|
||||
"external_group_permission_sync_attempt",
|
||||
["status", sa.text("time_finished DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index(
|
||||
"ix_group_sync_attempt_status_time",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_group_sync_attempt_cc_pair_time",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_external_group_permission_sync_attempt_time_created",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_permission_sync_attempt_status_time",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_permission_sync_attempt_latest_for_cc_pair",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_doc_permission_sync_attempt_time_created",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
|
||||
# Drop tables
|
||||
op.drop_table("external_group_permission_sync_attempt")
|
||||
op.drop_table("doc_permission_sync_attempt")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""reset userfile document_id_migrated field
|
||||
|
||||
Revision ID: 40926a4dab77
|
||||
Revises: 64bd5677aeb6
|
||||
Create Date: 2025-10-06 16:10:32.898668
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "40926a4dab77"
|
||||
down_revision = "64bd5677aeb6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set all existing records to not migrated
|
||||
op.execute(
|
||||
"UPDATE user_file SET document_id_migrated = FALSE "
|
||||
"WHERE document_id_migrated IS DISTINCT FROM FALSE;"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Do nothing: the reset performed by the upgrade is not reversed."""
    # No-op
    return None
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add queries and is web fetch to iteration answer
|
||||
|
||||
Revision ID: 6f4f86aef280
|
||||
Revises: 03d710ccf29c
|
||||
Create Date: 2025-10-14 18:08:30.920123
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6f4f86aef280"
|
||||
down_revision = "03d710ccf29c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add nullable ``is_web_fetch`` and ``queries`` columns to the sub-step table."""
    target_table = "research_agent_iteration_sub_step"

    # Columns are added in this order: is_web_fetch, then queries
    new_columns = (
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
    )
    for column in new_columns:
        op.add_column(target_table, column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the two added columns, most recently added first."""
    for column_name in ("queries", "is_web_fetch"):
        op.drop_column("research_agent_iteration_sub_step", column_name)
|
||||
45
backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py
Normal file
45
backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""mcp_tool_enabled
|
||||
|
||||
Revision ID: 96a5702df6aa
|
||||
Revises: 40926a4dab77
|
||||
Create Date: 2025-10-09 12:10:21.733097
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "96a5702df6aa"
|
||||
down_revision = "40926a4dab77"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add a non-null ``enabled`` flag to ``tool`` and index it with ``mcp_server_id``."""
    table_name = "tool"

    # The column is created NOT NULL with a server default of TRUE.
    enabled_column = sa.Column(
        "enabled",
        sa.Boolean(),
        nullable=False,
        server_default=sa.true(),
    )
    op.add_column(table_name, enabled_column)

    indexed_columns = ["mcp_server_id", "enabled"]
    op.create_index("ix_tool_mcp_server_enabled", table_name, indexed_columns)

    # Remove the server default so application controls defaulting
    op.alter_column(table_name, "enabled", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Delete disabled tools, then remove the index and the ``enabled`` column."""
    table_name = "tool"

    # Rows with enabled = false are deleted before the column is dropped.
    # NOTE(review): presumably so they do not reappear as active tools — confirm.
    op.execute(DELETE_DISABLED_TOOLS_SQL)

    op.drop_index("ix_tool_mcp_server_enabled", table_name=table_name)
    op.drop_column(table_name, "enabled")
|
||||
@@ -0,0 +1,72 @@
|
||||
"""personalization_user_info
|
||||
|
||||
Revision ID: c8a93a2af083
|
||||
Revises: 6f4f86aef280
|
||||
Create Date: 2025-10-14 15:59:03.577343
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c8a93a2af083"
|
||||
down_revision = "6f4f86aef280"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add personalization columns to ``user`` and create the ``memory`` table."""

    def _timestamp_column(column_name: str) -> sa.Column:
        # Timezone-aware, non-null timestamp defaulting to the server's now()
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        )

    # Optional free-text profile fields
    for profile_field in ("personal_name", "personal_role"):
        op.add_column("user", sa.Column(profile_field, sa.String(), nullable=True))

    # Non-null flag with a server default of TRUE
    use_memories_column = sa.Column(
        "use_memories",
        sa.Boolean(),
        nullable=False,
        server_default=sa.true(),
    )
    op.add_column("user", use_memories_column)

    op.create_table(
        "memory",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("memory_text", sa.Text(), nullable=False),
        sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("message_id", sa.Integer(), nullable=True),
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        # Deleting a user cascades to that user's memory rows
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_index("ix_memory_user_id", "memory", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the ``memory`` table and the personalization columns on ``user``."""
    # The index goes first, then the table it belongs to
    op.drop_index("ix_memory_user_id", table_name="memory")
    op.drop_table("memory")

    # Columns are dropped in the reverse order of their creation
    for column_name in ("use_memories", "personal_role", "personal_name"):
        op.drop_column("user", column_name)
|
||||
@@ -1,29 +1,17 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ee.onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
from ee.onyx.db.saml import get_saml_account
|
||||
from ee.onyx.server.seeding import get_seed_config
|
||||
from ee.onyx.utils.secrets import extract_hashed_cookie
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -31,75 +19,11 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache()
def get_public_key() -> str | None:
    """Fetch the JWT verification key from ``JWT_PUBLIC_KEY_URL``.

    The result is memoized by ``lru_cache``; this includes the ``None`` returned
    when the URL is not configured. Callers that need a fresh key must call
    ``get_public_key.cache_clear()``.

    Returns:
        The response body as text, or ``None`` if ``JWT_PUBLIC_KEY_URL`` is unset.

    Raises:
        requests.RequestException: if the request fails, times out, or the
            server answers with an error status.
    """
    if JWT_PUBLIC_KEY_URL is None:
        logger.error("JWT_PUBLIC_KEY_URL is not set")
        return None

    # Bound the wait: without a timeout, requests can block indefinitely on an
    # unresponsive key endpoint.
    response = requests.get(JWT_PUBLIC_KEY_URL, timeout=10)
    response.raise_for_status()
    return response.text
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
    """Resolve an RS256-signed JWT to the ``User`` whose email matches its claim.

    Returns ``None`` when no public key is available, the token fails to
    decode, the payload has no ``email`` value, or no user matches. On any
    decode failure the cached public key is cleared so the next call refetches it.
    """
    try:
        signing_key = get_public_key()
        if signing_key is None:
            logger.error("Failed to retrieve public key")
            return None

        claims = jwt_decode(
            token,
            signing_key,
            algorithms=["RS256"],
            audience=None,
        )

        claimed_email = claims.get("email")
        if not claimed_email:
            return None

        # Case-insensitive match on the email address
        email_matches = func.lower(User.email) == func.lower(claimed_email)
        lookup = await async_db_session.execute(select(User).where(email_matches))
        return lookup.scalars().first()
    except InvalidTokenError:
        logger.error("Invalid JWT token")
        get_public_key.cache_clear()
    except PyJWTError as e:
        logger.error(f"JWT decoding error: {str(e)}")
        get_public_key.cache_clear()

    return None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
    """Log the configured auth type; no validation is performed."""
    # All the Auth flows are valid for EE version
    auth_type_name = AUTH_TYPE.value
    logger.notice(f"Using Auth Type: {auth_type_name}")
|
||||
|
||||
|
||||
async def optional_user_(
    request: Request,
    user: User | None,
    async_db_session: AsyncSession,
) -> User | None:
    """Resolve the request's user via SAML cookie, then via a bearer JWT.

    Under SAML auth, a hashed cookie found on the request replaces ``user``
    with the matching SAML account's user (or ``None`` if no account matches).
    If ``user`` is still ``None`` and ``JWT_PUBLIC_KEY_URL`` is configured, a
    ``Bearer`` token in the ``Authorization`` header is verified instead.
    """
    bearer_prefix = "Bearer "

    # Check if the user has a session cookie from SAML
    if AUTH_TYPE == AuthType.SAML:
        hashed_cookie = extract_hashed_cookie(request)
        if hashed_cookie:
            account = await get_saml_account(
                cookie=hashed_cookie, async_db_session=async_db_session
            )
            if account:
                user = account.user
            else:
                user = None

    # Only fall back to the JWT path when no user was resolved and a key URL exists
    if user is not None or JWT_PUBLIC_KEY_URL is None:
        return user

    authorization = request.headers.get("Authorization")
    if authorization and authorization.startswith(bearer_prefix):
        raw_token = authorization[len(bearer_prefix) :].strip()
        user = await verify_jwt_token(raw_token, async_db_session)

    return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
seed_config = get_seed_config()
|
||||
if seed_config and seed_config.admin_user_emails:
|
||||
|
||||
12
backend/ee/onyx/background/celery/apps/background.py
Normal file
12
backend/ee/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
@@ -1,123 +1,4 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,
) -> None:
    """Export chat history between ``start`` and ``end`` as a CSV report.

    Question/answer rows are written to an in-memory CSV, which is then saved
    through the default file store under the report name derived from the
    task id. On success the task row is deleted; on any failure the task is
    marked as finished unsuccessfully and the exception is re-raised.

    Raises:
        RuntimeError: if the Celery request carries no task id.
    """
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    csv_buffer = io.StringIO()
    column_names = list(QuestionAnswerPairSnapshot.model_fields.keys())
    csv_writer = csv.DictWriter(csv_buffer, fieldnames=column_names)
    csv_writer.writeheader()

    # Phase 1: build the CSV in memory
    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(db_session=db_session, task_id=task_id)

            chat_snapshots = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )
            for chat_snapshot in chat_snapshots:
                # Replace the real email when anonymized history is configured
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    chat_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                qa_pairs = QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                    chat_snapshot
                )
                for qa_pair in qa_pairs:
                    csv_writer.writerow(qa_pair.to_json())
        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    # Phase 2: persist the CSV and clean up the task row
    report_name = construct_query_history_report_name(task_id)
    report_metadata = {
        "start": start.isoformat(),
        "end": end.isoformat(),
        "start_time": start_time.isoformat(),
    }
    with get_session_with_current_tenant() as db_session:
        try:
            csv_buffer.seek(0)
            file_store = get_default_file_store()
            file_store.save_file(
                content=csv_buffer,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata=report_metadata,
                file_id=report_name,
            )

            delete_task_with_id(db_session=db_session, task_id=task_id)
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -52,10 +51,18 @@ def cloud_beat_task_generator(
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
gated_tenants = get_gated_tenants()
|
||||
|
||||
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
|
||||
# connector deletion to run successfully. The new plan is to continuously prune
|
||||
# the gated tenants set, so we won't have a build up of old, unused gated tenants.
|
||||
# Keeping this around in case we want to revert to the previous behavior.
|
||||
# gated_tenants = get_gated_tenants()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
continue
|
||||
|
||||
# Same comment here as the above NOTE
|
||||
# if tenant_id in gated_tenants:
|
||||
# continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
|
||||
@@ -56,6 +56,12 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import mark_doc_permission_sync_attempt_failed
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_doc_permission_sync_attempt_in_progress,
|
||||
)
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
@@ -113,6 +119,14 @@ def _get_fence_validation_block_expiration() -> int:
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
def _fail_doc_permission_sync_attempt(attempt_id: int, error_msg: str) -> None:
    """Mark the doc permission sync attempt as failed, recording ``error_msg``.

    A new tenant-scoped session is opened for the update.
    """
    with get_session_with_current_tenant() as session:
        mark_doc_permission_sync_attempt_failed(
            attempt_id,
            session,
            error_message=error_msg,
        )
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
@@ -379,6 +393,15 @@ def connector_permission_sync_generator_task(
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt_id = create_doc_permission_sync_attempt(
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Created doc permission sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
@@ -389,22 +412,28 @@ def connector_permission_sync_generator_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -432,9 +461,11 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
error_msg = (
|
||||
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
task_logger.warning(error_msg)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -470,11 +501,15 @@ def connector_permission_sync_generator_task(
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {source_type}")
|
||||
error_msg = f"No sync config found for {source_type}"
|
||||
logger.error(error_msg)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
if sync_config.doc_sync_config is None:
|
||||
if sync_config.censoring_config:
|
||||
error_msg = f"Doc sync config is None but censoring config exists for {source_type}"
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
@@ -483,6 +518,8 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
mark_doc_permission_sync_attempt_in_progress(attempt_id, db_session)
|
||||
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
@@ -533,8 +570,9 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
|
||||
tasks_generated = 0
|
||||
docs_with_errors = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.update_db(
|
||||
result = redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
@@ -542,11 +580,23 @@ def connector_permission_sync_generator_task(
|
||||
credential_id=cc_pair.credential.id,
|
||||
task_logger=task_logger,
|
||||
)
|
||||
tasks_generated += 1
|
||||
tasks_generated += result.num_updated
|
||||
docs_with_errors += result.num_errors
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated} docs_with_errors={docs_with_errors}"
|
||||
)
|
||||
|
||||
complete_doc_permission_sync_attempt(
|
||||
db_session=db_session,
|
||||
attempt_id=attempt_id,
|
||||
total_docs_synced=tasks_generated,
|
||||
docs_with_permission_errors=docs_with_errors,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Completed doc permission sync attempt {attempt_id}: "
|
||||
f"{tasks_generated} docs, {docs_with_errors} errors"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
@@ -561,6 +611,11 @@ def connector_permission_sync_generator_task(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_doc_permission_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=error_msg
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
|
||||
@@ -49,6 +49,16 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.permission_sync_attempt import complete_external_group_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
create_external_group_sync_attempt,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_external_group_sync_attempt_failed,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_external_group_sync_attempt_in_progress,
|
||||
)
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@@ -70,6 +80,14 @@ logger = setup_logger()
|
||||
_EXTERNAL_GROUP_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _fail_external_group_sync_attempt(attempt_id: int, error_msg: str) -> None:
    """Mark the external group sync attempt as failed, recording ``error_msg``.

    A new tenant-scoped session is opened for the update.
    """
    with get_session_with_current_tenant() as session:
        mark_external_group_sync_attempt_failed(
            attempt_id,
            session,
            error_message=error_msg,
        )
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
@@ -449,6 +467,16 @@ def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# Create attempt record at the start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt_id = create_external_group_sync_attempt(
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
logger.info(
|
||||
f"Created external group sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
@@ -463,11 +491,13 @@ def _perform_external_group_sync(
|
||||
if sync_config is None:
|
||||
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
_fail_external_group_sync_attempt(attempt_id, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if sync_config.group_sync_config is None:
|
||||
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
_fail_external_group_sync_attempt(attempt_id, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
@@ -477,14 +507,27 @@ def _perform_external_group_sync(
|
||||
)
|
||||
mark_old_external_groups_as_stale(db_session, cc_pair_id)
|
||||
|
||||
# Mark attempt as in progress
|
||||
mark_external_group_sync_attempt_in_progress(attempt_id, db_session)
|
||||
logger.info(f"Marked external group sync attempt {attempt_id} as in progress")
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_group_batch: list[ExternalUserGroup] = []
|
||||
seen_users: set[str] = set() # Track unique users across all groups
|
||||
total_groups_processed = 0
|
||||
total_group_memberships_synced = 0
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
external_user_group_batch.append(external_user_group)
|
||||
|
||||
# Track progress
|
||||
total_groups_processed += 1
|
||||
total_group_memberships_synced += len(external_user_group.user_emails)
|
||||
seen_users = seen_users.union(external_user_group.user_emails)
|
||||
|
||||
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
|
||||
logger.debug(
|
||||
f"New external user groups: {external_user_group_batch}"
|
||||
@@ -506,6 +549,13 @@ def _perform_external_group_sync(
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
except Exception as e:
|
||||
format_error_for_logging(e)
|
||||
|
||||
# Mark as failed (this also updates progress to show partial progress)
|
||||
mark_external_group_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=str(e)
|
||||
)
|
||||
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
@@ -517,6 +567,24 @@ def _perform_external_group_sync(
|
||||
)
|
||||
remove_stale_external_groups(db_session, cc_pair_id)
|
||||
|
||||
# Calculate total unique users processed
|
||||
total_users_processed = len(seen_users)
|
||||
|
||||
# Complete the sync attempt with final progress
|
||||
complete_external_group_sync_attempt(
|
||||
db_session=db_session,
|
||||
attempt_id=attempt_id,
|
||||
total_users_processed=total_users_processed,
|
||||
total_groups_processed=total_groups_processed,
|
||||
total_group_memberships_synced=total_group_memberships_synced,
|
||||
errors_encountered=0,
|
||||
)
|
||||
logger.info(
|
||||
f"Completed external group sync attempt {attempt_id}: "
|
||||
f"{total_groups_processed} groups, {total_users_processed} users, "
|
||||
f"{total_group_memberships_synced} memberships"
|
||||
)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
|
||||
|
||||
119
backend/ee/onyx/background/celery/tasks/query_history/tasks.py
Normal file
119
backend/ee/onyx/background/celery/tasks/query_history/tasks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,
) -> None:
    """Export chat history between ``start`` and ``end`` as a CSV report.

    Question/answer rows are written to an in-memory CSV, which is then saved
    through the default file store under the report name derived from the
    task id. On success the task row is deleted; on any failure the task is
    marked as finished unsuccessfully and the exception is re-raised.

    Raises:
        RuntimeError: if the Celery request carries no task id.
    """
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    csv_buffer = io.StringIO()
    column_names = list(QuestionAnswerPairSnapshot.model_fields.keys())
    csv_writer = csv.DictWriter(csv_buffer, fieldnames=column_names)
    csv_writer.writeheader()

    # Phase 1: build the CSV in memory
    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(db_session=db_session, task_id=task_id)

            chat_snapshots = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )
            for chat_snapshot in chat_snapshots:
                # Replace the real email when anonymized history is configured
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    chat_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                qa_pairs = QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                    chat_snapshot
                )
                for qa_pair in qa_pairs:
                    csv_writer.writerow(qa_pair.to_json())
        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    # Phase 2: persist the CSV and clean up the task row
    report_name = construct_query_history_report_name(task_id)
    report_metadata = {
        "start": start.isoformat(),
        "end": end.isoformat(),
        "start_time": start_time.isoformat(),
    }
    with get_session_with_current_tenant() as db_session:
        try:
            csv_buffer.seek(0)
            file_store = get_default_file_store()
            file_store.save_file(
                content=csv_buffer,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata=report_metadata,
                file_id=report_name,
            )

            delete_task_with_id(db_session=db_session, task_id=task_id)
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise
|
||||
@@ -1,26 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
# Applicable for OIDC Auth
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
|
||||
# Applicable for OIDC Auth, allows you to override the scopes that
|
||||
# are requested from the OIDC provider. Currently used when passing
|
||||
# over access tokens to tool calls and the tool needs more scopes
|
||||
OIDC_SCOPE_OVERRIDE: list[str] | None = None
|
||||
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
|
||||
|
||||
if _OIDC_SCOPE_OVERRIDE:
|
||||
try:
|
||||
OIDC_SCOPE_OVERRIDE = [
|
||||
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
|
||||
|
||||
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
|
||||
@@ -73,6 +73,12 @@ def fetch_per_user_query_analytics(
|
||||
ChatSession.user_id,
|
||||
)
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
# Include chats that have no explicit feedback instead of dropping them
|
||||
.join(
|
||||
ChatMessageFeedback,
|
||||
ChatMessageFeedback.chat_message_id == ChatMessage.id,
|
||||
isouter=True,
|
||||
)
|
||||
.where(
|
||||
ChatMessage.time_sent >= start,
|
||||
)
|
||||
|
||||
@@ -50,6 +50,25 @@ def get_empty_chat_messages_entries__paginated(
|
||||
if message.message_type != MessageType.USER:
|
||||
continue
|
||||
|
||||
# Get user email
|
||||
user_email = chat_session.user.email if chat_session.user else None
|
||||
|
||||
# Get assistant name (from session persona, or alternate if specified)
|
||||
assistant_name = None
|
||||
if message.alternate_assistant_id:
|
||||
# If there's an alternate assistant, we need to fetch it
|
||||
from onyx.db.models import Persona
|
||||
|
||||
alternate_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.id == message.alternate_assistant_id)
|
||||
.first()
|
||||
)
|
||||
if alternate_persona:
|
||||
assistant_name = alternate_persona.name
|
||||
elif chat_session.persona:
|
||||
assistant_name = chat_session.persona.name
|
||||
|
||||
message_skeletons.append(
|
||||
ChatMessageSkeleton(
|
||||
message_id=message.id,
|
||||
@@ -57,6 +76,9 @@ def get_empty_chat_messages_entries__paginated(
|
||||
user_id=str(chat_session.user_id) if chat_session.user_id else None,
|
||||
flow_type=flow_type,
|
||||
time_sent=message.time_sent,
|
||||
assistant_name=assistant_name,
|
||||
user_email=user_email,
|
||||
number_of_tokens=message.token_count,
|
||||
)
|
||||
)
|
||||
if len(chat_sessions) == 0:
|
||||
|
||||
0
backend/ee/onyx/feature_flags/__init__.py
Normal file
0
backend/ee/onyx/feature_flags/__init__.py
Normal file
15
backend/ee/onyx/feature_flags/factory.py
Normal file
15
backend/ee/onyx/feature_flags/factory.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
|
||||
|
||||
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
    """
    Build the PostHog-backed feature flag provider.

    This is the EE implementation that gets loaded by the versioned
    implementation loader.

    Returns:
        FeatureFlagProvider: a freshly constructed PostHogFeatureFlagProvider
    """
    provider: FeatureFlagProvider = PostHogFeatureFlagProvider()
    return provider
|
||||
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PostHogFeatureFlagProvider(FeatureFlagProvider):
    """
    PostHog-based feature flag provider.

    Uses PostHog's feature flag API to determine if features are enabled
    for specific users. Only active in multi-tenant mode.
    """

    def feature_enabled(
        self,
        flag_key: str,
        user_id: UUID,
        user_properties: dict[str, Any] | None = None,
    ) -> bool:
        """
        Check if a feature flag is enabled for a user via PostHog.

        The user's properties are first pushed to PostHog, then the flag is
        evaluated for that same user. Any failure is logged and treated as
        "flag disabled" so that callers never have to handle PostHog errors.

        Args:
            flag_key: The identifier for the feature flag to check
            user_id: The unique identifier for the user
            user_properties: Optional dictionary of user properties/attributes
                that may influence flag evaluation

        Returns:
            True if the feature is enabled for the user, False otherwise
            (including when PostHog returns no result or the lookup fails).
        """
        try:
            # Use one string id for both calls so the properties are stored
            # on exactly the person the flag is evaluated for.
            distinct_id = str(user_id)

            posthog.set(
                distinct_id=distinct_id,
                properties=user_properties,
            )
            is_enabled = posthog.feature_enabled(
                flag_key,
                distinct_id,
                person_properties=user_properties,
            )

            # bool(None) is False, which covers the "no result" case.
            return bool(is_enabled)

        except Exception as e:
            # Best-effort: never let a flag lookup break the caller, but keep
            # the traceback in the logs for diagnosis.
            logger.exception(
                f"Error checking feature flag {flag_key} for user {user_id}: {e}"
            )
            return False
|
||||
@@ -3,11 +3,7 @@ from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
|
||||
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
|
||||
@@ -31,7 +27,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
)
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.saml import router as saml_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
from ee.onyx.server.token_rate_limits.api import (
|
||||
@@ -117,49 +112,6 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
# Ensure we request offline_access for refresh tokens
|
||||
try:
|
||||
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
|
||||
if "offline_access" not in oidc_scopes:
|
||||
oidc_scopes.append("offline_access")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error configuring OIDC scopes: {e}")
|
||||
# Fall back to default scopes if there's an error
|
||||
oidc_scopes = BASE_SCOPES
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# Use the configured scopes
|
||||
base_scopes=oidc_scopes,
|
||||
),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
# Analytics endpoints
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
from cohere import Client
|
||||
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
|
||||
Embedding = List[float]
|
||||
|
||||
|
||||
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
base_path = os.path.join(os.getcwd(), "onyx", "seeding")
|
||||
|
||||
if cohere_enabled and COHERE_DEFAULT_API_KEY:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
|
||||
embed_model = "embed-english-v3.0"
|
||||
|
||||
for doc in processed_docs:
|
||||
title_embed_response = cohere_client.embed(
|
||||
texts=[doc["title"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
content_embed_response = cohere_client.embed(
|
||||
texts=[doc["content"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
|
||||
doc["title_embedding"] = cast(
|
||||
List[Embedding], title_embed_response.embeddings
|
||||
)[0]
|
||||
doc["content_embedding"] = cast(
|
||||
List[Embedding], content_embed_response.embeddings
|
||||
)[0]
|
||||
else:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
return processed_docs
|
||||
@@ -10,14 +10,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
("/enterprise-settings/logotype", {"GET"}),
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# oidc
|
||||
("/auth/oidc/authorize", {"GET"}),
|
||||
("/auth/oidc/callback", {"GET"}),
|
||||
# saml
|
||||
("/auth/saml/authorize", {"GET"}),
|
||||
("/auth/saml/callback", {"POST"}),
|
||||
("/auth/saml/callback", {"GET"}),
|
||||
("/auth/saml/logout", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,17 @@ def generate_chat_messages_report(
|
||||
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
|
||||
) as temp_file:
|
||||
csvwriter = csv.writer(temp_file, delimiter=",")
|
||||
csvwriter.writerow(["session_id", "user_id", "flow_type", "time_sent"])
|
||||
csvwriter.writerow(
|
||||
[
|
||||
"session_id",
|
||||
"user_id",
|
||||
"flow_type",
|
||||
"time_sent",
|
||||
"assistant_name",
|
||||
"user_email",
|
||||
"number_of_tokens",
|
||||
]
|
||||
)
|
||||
for chat_message_skeleton_batch in get_all_empty_chat_message_entries(
|
||||
db_session, period
|
||||
):
|
||||
@@ -59,6 +69,9 @@ def generate_chat_messages_report(
|
||||
chat_message_skeleton.user_id,
|
||||
chat_message_skeleton.flow_type,
|
||||
chat_message_skeleton.time_sent.isoformat(),
|
||||
chat_message_skeleton.assistant_name,
|
||||
chat_message_skeleton.user_email,
|
||||
chat_message_skeleton.number_of_tokens,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ class ChatMessageSkeleton(BaseModel):
|
||||
user_id: str | None
|
||||
flow_type: FlowType
|
||||
time_sent: datetime
|
||||
assistant_name: str | None
|
||||
user_email: str | None
|
||||
number_of_tokens: int
|
||||
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
|
||||
@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in ANTHROPIC_MODEL_NAMES
|
||||
for name in get_anthropic_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
22
backend/ee/onyx/utils/posthog_client.py
Normal file
22
backend/ee/onyx/utils/posthog_client.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
    """Error callback for the PostHog client: log the error and the affected items."""
    message = f"PostHog error: {error}, items: {items}"
    logger.error(message)
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
@@ -1,27 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
@@ -141,7 +148,9 @@ def get_local_intent_model(
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
) -> "SetFitModel":
|
||||
from setfit import SetFitModel
|
||||
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
@@ -179,7 +188,7 @@ def get_local_information_content_model(
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
|
||||
@@ -401,7 +410,7 @@ def run_content_classification_inference(
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
|
||||
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -19,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
@@ -87,8 +89,10 @@ def get_embedding_model(
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> CrossEncoder:
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
model = CrossEncoder(model_name)
|
||||
|
||||
@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
if not SKIP_WARM_UP:
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice("Warming up intent model for inference model server")
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"Warming up content information model for indexing model server"
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
from transformers import DistilBertModel # type: ignore
|
||||
from transformers import DistilBertTokenizer # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
config = DistilBertConfig()
|
||||
self.distilbert = DistilBertModel(config)
|
||||
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: DistilBertConfig) -> None:
|
||||
def __init__(self, config: "DistilBertConfig") -> None:
|
||||
from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
config = cast(
|
||||
DistilBertConfig,
|
||||
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
|
||||
78
backend/onyx/agents/agent_sdk/message_format.py
Normal file
78
backend/onyx/agents/agent_sdk/message_format.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
# TODO: Currently, we only support native API input for images. For other
|
||||
# files, we process the content and share it as text in the message. In
|
||||
# the future, we might support native file uploads for other types of files.
|
||||
def base_messages_to_agent_sdk_msgs(msgs: Sequence[BaseMessage]) -> list[dict]:
    """Convert LangChain messages into Agent SDK input message dicts, in order.

    Raises:
        KeyError: if a message has a type with no Agent SDK role mapping.
        ValueError: if a message's content (or a list item inside it) has an
            unexpected Python type.
    """
    return [_base_message_to_agent_sdk_msg(msg) for msg in msgs]


# LangChain message ``type`` -> Agent SDK ``role``. Built once at import time
# instead of on every conversion.
_MESSAGE_TYPE_TO_AGENT_SDK_ROLE = {
    "human": "user",
    "system": "system",
    "ai": "assistant",
}


def _base_message_to_agent_sdk_msg(msg: BaseMessage) -> dict:
    """Convert a single LangChain message into an Agent SDK message dict.

    String content becomes one ``input_text`` part. List content is converted
    item by item: strings and ``{"type": "text"}`` dicts become ``input_text``
    parts, ``{"type": "image_url"}`` dicts become ``input_image`` parts.

    Returns:
        ``{"role": <role>, "content": [<parts>...]}``

    Raises:
        KeyError: if ``msg.type`` is not one of the supported message types.
        ValueError: if the content is neither ``str`` nor ``list``, or a list
            item is neither ``str`` nor ``dict``.
    """
    try:
        role = _MESSAGE_TYPE_TO_AGENT_SDK_ROLE[msg.type]
    except KeyError:
        # Same exception type as a plain dict lookup, but say what went wrong.
        raise KeyError(
            f"Unsupported message type for Agent SDK conversion: {msg.type!r}"
        ) from None

    # Convert content to Agent SDK format
    content = msg.content
    if isinstance(content, str):
        # Convert string to structured text format
        structured_content = [
            {
                "type": "input_text",
                "text": content,
            }
        ]
    elif isinstance(content, list):
        # Content is already a list, process each item
        structured_content = []
        for item in content:
            if isinstance(item, str):
                structured_content.append(
                    {
                        "type": "input_text",
                        "text": item,
                    }
                )
            elif isinstance(item, dict):
                # Handle different item types
                item_type = item.get("type")

                if item_type == "text":
                    # Convert text type to input_text
                    structured_content.append(
                        {
                            "type": "input_text",
                            "text": item.get("text", ""),
                        }
                    )
                elif item_type == "image_url":
                    # Convert image_url to input_image format; the value may
                    # be a {"url": ...} dict or the URL itself.
                    image_url = item.get("image_url", {})
                    if isinstance(image_url, dict):
                        url = image_url.get("url", "")
                    else:
                        url = image_url
                    structured_content.append(
                        {
                            "type": "input_image",
                            "image_url": url,
                            "detail": "auto",
                        }
                    )
                # NOTE(review): dict items with any other "type" are skipped
                # without error — confirm that this is intended.
            else:
                raise ValueError(f"Unexpected item type: {type(item)}. Item: {item}")
    else:
        raise ValueError(
            f"Unexpected content type: {type(content)}. Content: {content}"
        )

    return {
        "role": role,
        "content": structured_content,
    }
|
||||
176
backend/onyx/agents/agent_sdk/sync_agent_stream_adapter.py
Normal file
176
backend/onyx/agents/agent_sdk/sync_agent_stream_adapter.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from agents import Agent
|
||||
from agents import RunResultStreaming
|
||||
from agents import TContext
|
||||
from agents.run import Runner
|
||||
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SyncAgentStream(Generic[T]):
    """
    Convert an async streamed run into a sync iterator with cooperative cancellation.
    Runs the Agent in a background thread.

    Usage:
        adapter = SyncAgentStream(
            agent=agent,
            input=input,
            context=context,
            max_turns=100,
            queue_maxsize=0,  # optional backpressure
        )
        for ev in adapter:  # sync iteration
            ...
        # or cancel from elsewhere:
        adapter.cancel()
    """

    # Unique marker placed on the queue to signal end-of-stream.
    _SENTINEL = object()

    def __init__(
        self,
        *,
        agent: Agent,
        input: list[dict],
        context: TContext | None = None,
        max_turns: int = 100,
        queue_maxsize: int = 0,
    ) -> None:
        """Store the run parameters and immediately start the background worker."""
        self._agent = agent
        self._input = input
        self._context = context
        self._max_turns = max_turns

        # Events flow from the worker thread to the consumer through this queue.
        self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self.streamed: RunResultStreaming | None = None
        # Exception raised inside the worker, re-raised on the consumer side.
        self._exc: Optional[BaseException] = None
        self._cancel_requested = threading.Event()
        self._started = threading.Event()
        self._done = threading.Event()

        self._start_thread()

    # ---------- public sync API ----------

    def __iter__(self) -> Iterator[T]:
        """Yield streamed events until the worker signals completion.

        Re-raises any exception captured in the worker once the end-of-stream
        marker is reached. Always calls ``close()`` when iteration ends.
        """
        try:
            while True:
                item = self._q.get()
                if item is self._SENTINEL:
                    # If the worker raised, surface it now
                    if self._exc is not None:
                        raise self._exc
                    # Normal completion
                    return
                yield item  # type: ignore[misc,return-value]
        finally:
            # Ensure we fully clean up whether we exited due to exception,
            # StopIteration, or external cancel.
            self.close()

    def cancel(self) -> bool:
        """
        Cooperatively cancel the underlying streamed run and shut down.
        Safe to call multiple times and from any thread.

        Returns:
            True if a cancel was scheduled on the worker's event loop, False
            if the loop/run is not available (not started yet or already done).
            In the not-started case the request is still recorded and honored
            by the worker once the run has been created.
        """
        self._cancel_requested.set()
        loop = self._loop
        streamed = self.streamed
        if loop is not None and streamed is not None and not self._done.is_set():
            loop.call_soon_threadsafe(streamed.cancel)
            return True
        return False

    def close(self, *, wait: bool = True) -> None:
        """Idempotent shutdown.

        Requests cancellation, asks the loop to stop if it is still running
        and, when ``wait`` is true, joins the worker thread for up to 5 seconds.
        """
        self.cancel()
        # ask the loop to stop if it's still running
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except Exception:
                # The loop may have been closed between the check and the
                # call; shutdown is best-effort.
                pass
        # join the thread
        if wait and self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=5.0)

    # ---------- internals ----------

    def _start_thread(self) -> None:
        """Launch ``_thread_main`` in the background and briefly wait for it to start."""
        t = run_in_background(self._thread_main)
        self._thread = t
        # Optionally wait until the loop/worker is started so .cancel() is safe soon after init
        self._started.wait(timeout=1.0)

    def _thread_main(self) -> None:
        """Worker-thread entry point: own an event loop and pump events into the queue."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)

        async def worker() -> None:
            try:
                # Start the streamed run inside the loop thread
                self.streamed = Runner.run_streamed(
                    self._agent,
                    self._input,  # type: ignore[arg-type]
                    context=self._context,
                    max_turns=self._max_turns,
                )

                # If cancel was requested before we created the run, honor it
                # now. cancel() is a plain synchronous call (it is also passed
                # to call_soon_threadsafe as such); awaiting its None result
                # would raise TypeError and surface as a spurious failure.
                if self._cancel_requested.is_set():
                    self.streamed.cancel()

                # Consume async events and forward into the thread-safe queue
                async for ev in self.streamed.stream_events():
                    # Early exit if a late cancel arrives
                    if self._cancel_requested.is_set():
                        # Best-effort cancel, then stop forwarding events
                        try:
                            self.streamed.cancel()
                        except Exception:
                            pass
                        break
                    # This put() may block if queue_maxsize > 0 (backpressure)
                    self._q.put(ev)

            except BaseException as e:
                # Save exception to surface on the sync iterator side
                self._exc = e
            finally:
                # Signal end-of-stream
                self._q.put(self._SENTINEL)
                self._done.set()

        # Mark started and run the worker to completion
        self._started.set()
        try:
            loop.run_until_complete(worker())
        finally:
            try:
                # Drain pending tasks/callbacks safely
                pending = asyncio.all_tasks(loop=loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
            except Exception:
                pass
            finally:
                loop.close()
                self._loop = None
|
||||
@@ -100,9 +100,14 @@ class IterationAnswer(BaseModel):
|
||||
response_type: str | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
# TODO: This is not ideal, but we can rework the schema
|
||||
# for deep research later
|
||||
is_web_fetch: bool = False
|
||||
# for image generation step-types
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# for multi-query search tools (v2 web search and internal search)
|
||||
# TODO: Clean this up to be more flexible to tools
|
||||
queries: list[str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from braintrust import traced
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import (
|
||||
BasicSearchProcessedStreamResults,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
@@ -37,6 +41,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import build_citation_map_from_numbers
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.memories import make_memories_callback
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
|
||||
@@ -70,6 +75,8 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_memories
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -116,7 +123,9 @@ def _get_available_tools(
|
||||
else:
|
||||
include_kg = False
|
||||
|
||||
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
|
||||
tool_dict: dict[int, Tool] = {
|
||||
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
|
||||
}
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
@@ -484,6 +493,16 @@ def clarifier(
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
if graph_config.tooling.search_tool
|
||||
else None
|
||||
)
|
||||
memories_callback = make_memories_callback(user, db_session)
|
||||
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
|
||||
assistant_system_prompt = handle_memories(
|
||||
assistant_system_prompt, memories_callback
|
||||
)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
@@ -666,28 +685,30 @@ def clarifier(
|
||||
system_prompt_to_use = assistant_system_prompt
|
||||
user_prompt_to_use = decision_prompt + assistant_task_prompt
|
||||
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
|
||||
full_response = process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
@traced(name="clarifier stream and process", type="llm")
|
||||
def stream_and_process() -> BasicSearchProcessedStreamResults:
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
return process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
|
||||
full_response = stream_and_process()
|
||||
if len(full_response.ai_message_chunk.tool_calls) == 0:
|
||||
|
||||
if isinstance(full_response.full_answer, str):
|
||||
|
||||
@@ -199,6 +199,7 @@ def save_iteration(
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
|
||||
@@ -180,6 +180,7 @@ def save_iteration(
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,6 +64,29 @@ def image_generation(
|
||||
image_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
|
||||
|
||||
image_prompt = branch_query
|
||||
requested_shape: ImageShape | None = None
|
||||
|
||||
try:
|
||||
parsed_query = json.loads(branch_query)
|
||||
except json.JSONDecodeError:
|
||||
parsed_query = None
|
||||
|
||||
if isinstance(parsed_query, dict):
|
||||
prompt_from_llm = parsed_query.get("prompt")
|
||||
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
|
||||
image_prompt = prompt_from_llm.strip()
|
||||
|
||||
raw_shape = parsed_query.get("shape")
|
||||
if isinstance(raw_shape, str):
|
||||
try:
|
||||
requested_shape = ImageShape(raw_shape)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Received unsupported image shape '%s' from LLM. Falling back to square.",
|
||||
raw_shape,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
@@ -69,7 +94,15 @@ def image_generation(
|
||||
# Generate images using the image generation tool
|
||||
image_generation_responses: list[ImageGenerationResponse] = []
|
||||
|
||||
for tool_response in image_tool.run(prompt=branch_query):
|
||||
if requested_shape is not None:
|
||||
tool_iterator = image_tool.run(
|
||||
prompt=image_prompt,
|
||||
shape=requested_shape.value,
|
||||
)
|
||||
else:
|
||||
tool_iterator = image_tool.run(prompt=image_prompt)
|
||||
|
||||
for tool_response in tool_iterator:
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Stream heartbeat to frontend
|
||||
write_custom_event(
|
||||
@@ -95,6 +128,7 @@ def image_generation(
|
||||
file_id=file_id,
|
||||
url=build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
shape=(requested_shape or ImageShape.SQUARE).value,
|
||||
)
|
||||
for file_id, img in zip(file_ids, image_generation_responses)
|
||||
]
|
||||
@@ -107,15 +141,29 @@ def image_generation(
|
||||
if final_generated_images:
|
||||
image_descriptions = []
|
||||
for i, img in enumerate(final_generated_images, 1):
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
if img.shape and img.shape != ImageShape.SQUARE.value:
|
||||
image_descriptions.append(
|
||||
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
|
||||
)
|
||||
else:
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
|
||||
answer_string = (
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
|
||||
+ "\n".join(image_descriptions)
|
||||
)
|
||||
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
|
||||
if requested_shape:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
|
||||
)
|
||||
else:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) based on the user's request."
|
||||
)
|
||||
else:
|
||||
answer_string = f"Failed to generate images for request: {branch_query}"
|
||||
answer_string = f"Failed to generate images for request: {image_prompt}"
|
||||
reasoning = "Image generation tool did not return any results."
|
||||
|
||||
return BranchUpdate(
|
||||
|
||||
@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
|
||||
file_id: str
|
||||
url: str
|
||||
revised_prompt: str
|
||||
shape: str | None = None
|
||||
|
||||
|
||||
# Needed for PydanticType
|
||||
|
||||
@@ -2,30 +2,28 @@ from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
# TODO Dependency inject for testing
|
||||
class ExaClient(InternetSearchProvider):
|
||||
class ExaClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="fast",
|
||||
livecrawl="never",
|
||||
type="auto",
|
||||
highlights=HighlightsContentsOptions(
|
||||
num_sentences=2,
|
||||
highlights_per_url=1,
|
||||
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
WebSearchResult(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
snippet=result.highlights[0] if result.highlights else "",
|
||||
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
|
||||
]
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
response = self.exa.get_contents(
|
||||
urls=urls,
|
||||
text=True,
|
||||
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetContent(
|
||||
WebContent(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
full_content=result.text or "",
|
||||
|
||||
@@ -4,13 +4,13 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
@@ -20,7 +20,7 @@ SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(InternetSearchProvider):
|
||||
class SerperClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
@@ -28,7 +28,7 @@ class SerperClient(InternetSearchProvider):
|
||||
}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
@@ -45,7 +45,7 @@ class SerperClient(InternetSearchProvider):
|
||||
organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
@@ -55,17 +55,17 @@ class SerperClient(InternetSearchProvider):
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Serper can respond with 500s regularly. We want to retry,
|
||||
# but in the event of failure, return an unsuccessful scrape.
|
||||
def safe_get_webpage_content(url: str) -> InternetContent:
|
||||
def safe_get_webpage_content(url: str) -> WebContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
@@ -77,7 +77,7 @@ class SerperClient(InternetSearchProvider):
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> InternetContent:
|
||||
def _get_webpage_content(self, url: str) -> WebContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
@@ -90,7 +90,7 @@ class SerperClient(InternetSearchProvider):
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
@@ -122,7 +122,7 @@ class SerperClient(InternetSearchProvider):
|
||||
except Exception:
|
||||
published_date = None
|
||||
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
|
||||
@@ -7,7 +7,7 @@ from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
@@ -75,15 +75,15 @@ def web_search(
|
||||
raise ValueError("No internet search provider found")
|
||||
|
||||
@traceable(name="Search Provider API Call")
|
||||
def _search(search_query: str) -> list[InternetSearchResult]:
|
||||
search_results: list[InternetSearchResult] = []
|
||||
def _search(search_query: str) -> list[WebSearchResult]:
|
||||
search_results: list[WebSearchResult] = []
|
||||
try:
|
||||
search_results = provider.search(search_query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing search: {e}")
|
||||
return search_results
|
||||
|
||||
search_results: list[InternetSearchResult] = _search(search_query)
|
||||
search_results: list[WebSearchResult] = _search(search_query)
|
||||
search_results_text = "\n\n".join(
|
||||
[
|
||||
f"{i}. {result.title}\n URL: {result.link}\n"
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
@@ -23,7 +23,7 @@ def dedup_urls(
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchInput:
|
||||
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
|
||||
unique_results_by_link: dict[str, InternetSearchResult] = {}
|
||||
unique_results_by_link: dict[str, WebSearchResult] = {}
|
||||
for query, result in state.results_to_open:
|
||||
branch_questions_to_urls[query].append(result.link)
|
||||
if result.link not in unique_results_by_link:
|
||||
|
||||
@@ -13,7 +13,7 @@ class ProviderType(Enum):
|
||||
EXA = "exa"
|
||||
|
||||
|
||||
class InternetSearchResult(BaseModel):
|
||||
class WebSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
author: str | None = None
|
||||
@@ -21,7 +21,7 @@ class InternetSearchResult(BaseModel):
|
||||
snippet: str | None = None
|
||||
|
||||
|
||||
class InternetContent(BaseModel):
|
||||
class WebContent(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
full_content: str
|
||||
@@ -29,11 +29,11 @@ class InternetContent(BaseModel):
|
||||
scrape_successful: bool = True
|
||||
|
||||
|
||||
class InternetSearchProvider(ABC):
|
||||
class WebSearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
pass
|
||||
|
||||
@@ -5,13 +5,13 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client imp
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> InternetSearchProvider | None:
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
if EXA_API_KEY:
|
||||
return ExaClient()
|
||||
if SERPER_API_KEY:
|
||||
|
||||
@@ -4,13 +4,13 @@ from typing import Annotated
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class InternetSearchInput(SubAgentInput):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
parallelization_nr: int = 0
|
||||
branch_question: Annotated[str, lambda x, y: y] = ""
|
||||
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
|
||||
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
|
||||
|
||||
|
||||
class InternetSearchUpdate(LoggerUpdate):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
|
||||
|
||||
class FetchInput(SubAgentInput):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
|
||||
result: InternetContent,
|
||||
result: WebContent,
|
||||
) -> InferenceSection:
|
||||
truncated_content = truncate_search_result_content(result.full_content)
|
||||
return InferenceSection(
|
||||
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
|
||||
result: InternetSearchResult,
|
||||
result: WebSearchResult,
|
||||
) -> InferenceSection:
|
||||
return InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
@@ -88,11 +87,5 @@ class GraphConfig(BaseModel):
|
||||
# Only needed for agentic search
|
||||
persistence: GraphPersistence
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "GraphConfig":
|
||||
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolCallException(Exception):
    """Exception raised for errors during tool calls.

    In this module it is raised by ``call_tool`` to wrap any exception that
    occurs while collecting a tool's responses or its final result; the
    underlying exception is attached as the cause.
    """
|
||||
|
||||
|
||||
def call_tool(
    state: ToolChoiceUpdate,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
    """Run the tool selected in ``state`` and package everything it produced.

    Args:
        state: Carries ``tool_choice`` (the tool, its arguments, its call id
            and any search-tool override kwargs).
        config: Runnable config; ``config["metadata"]["config"]`` is looked up
            but its value is not used by this function.
        writer: Stream writer; not used by this function.

    Returns:
        A ``ToolCallUpdate`` whose ``tool_call_output`` holds the call summary
        (request + result message), the kickoff value, every streamed
        response, and the final result.

    Raises:
        ValueError: If ``state.tool_choice`` is ``None``.
        ToolCallException: If collecting the tool responses or the final
            result raises; the original exception is chained as the cause.
    """
    # The looked-up value is discarded; the lookup itself is retained so that a
    # config lacking these keys still fails at this point.
    cast(GraphConfig, config["metadata"]["config"])

    chosen = state.tool_choice
    if chosen is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    tool, tool_args, tool_id = chosen.tool, chosen.tool_args, chosen.id
    runner = ToolRunner(
        tool, tool_args, override_kwargs=chosen.search_tool_override_kwargs
    )
    # The kickoff happens outside the try block, so its errors are not wrapped.
    kickoff = runner.kickoff()

    try:
        responses = list(runner.tool_responses())
        final_result = runner.tool_final_result()
    except Exception as e:
        raise ToolCallException(
            f"Error during tool call for {tool.display_name}: {e}"
        ) from e

    request = ToolCall(name=tool.name, args=tool_args, id=tool_id)
    return ToolCallUpdate(
        tool_call_output=ToolCallOutput(
            tool_call_summary=ToolCallSummary(
                tool_call_request=AIMessageChunk(content="", tool_calls=[request]),
                tool_call_result=build_tool_message(
                    request, runner.tool_message_content()
                ),
            ),
            tool_call_kickoff=kickoff,
            tool_call_responses=responses,
            tool_call_final_result=final_result,
        )
    )
|
||||
@@ -1,354 +0,0 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.tools.models import QueryExpansions
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
    """Render the human and AI turns of the message history as one string.

    Messages that are neither ``HumanMessage`` nor ``AIMessage`` are left out.
    Returns an empty string when no message qualifies.
    """
    # TODO: Add trimming logic

    def _speaker(message: object) -> str | None:
        # Map a message to its display role, or None if it should be skipped.
        if isinstance(message, HumanMessage):
            return "User"
        if isinstance(message, AIMessage):
            return "Assistant"
        return None

    labelled = (
        (_speaker(message), message) for message in prompt_builder.message_history
    )
    return "\n".join(
        f"{speaker}:\n {message.content}\n\n"
        for speaker, message in labelled
        if speaker is not None
    )
|
||||
|
||||
|
||||
def _expand_query(
    query: str,
    expansion_type: QueryExpansionType,
    prompt_builder: AnswerPromptBuilder,
) -> str:
    """Ask the default primary LLM to rephrase ``query``.

    The prompt template is picked by two switches: whether the rendered chat
    history is non-empty, and whether ``expansion_type`` is ``KEYWORD`` (any
    other value uses the semantic template). The history is only passed to the
    template when it is non-empty.

    Returns:
        The content of the LLM response, treated as a string.
    """
    history = _create_history_str(prompt_builder)
    wants_keywords = expansion_type == QueryExpansionType.KEYWORD

    if history:
        template = (
            QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
            if wants_keywords
            else QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
        )
        prompt_text = template.format(question=query, history=history)
    else:
        template = (
            QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
            if wants_keywords
            else QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
        )
        prompt_text = template.format(question=query)

    llm, _ = get_default_llms()
    reply = llm.invoke([HumanMessage(content=prompt_text)])
    return cast(str, reply.content)
|
||||
|
||||
|
||||
def _expand_query_non_tool_calling_llm(
    expanded_keyword_thread: TimeoutThread[str],
    expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
    """Wait for both background expansions and bundle them together.

    Both threads are always waited on, keyword first. Returns ``None`` if
    either wait yields ``None``; otherwise a ``QueryExpansions`` holding each
    result as a single-element list.
    """
    keyword_result: str | None = wait_on_background(expanded_keyword_thread)
    semantic_result: str | None = wait_on_background(expanded_semantic_thread)

    if keyword_result is not None and semantic_result is not None:
        return QueryExpansions(
            keywords_expansions=[keyword_result],
            semantic_expansions=[semantic_result],
        )
    return None
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
def choose_tool(
    state: ToolChoiceState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
    """
    Graph node that decides which tool (if any) should be called next.

    There are three ways out of this function:

    1. A tool and its arguments are already known, either because the caller
       forced a tool with explicit args or because a non-tool-calling LLM was
       asked to pick one up front. The choice is returned immediately with a
       freshly generated id.
    2. Answer generation is being skipped and no tool is forced: no tool is
       chosen and the LLM is never streamed.
    3. Otherwise the LLM is streamed. The first emitted tool call whose name
       matches one of the available tools becomes the choice; if the LLM emits
       no tool calls, no tool is chosen. Answer tokens are forwarded to
       `writer` only when `state.should_stream_answer` is set and answer
       generation is not being skipped.

    When the search tool may end up being used, query embedding, keyword
    analysis and (optionally) query expansion are started in background
    threads before the LLM is consulted. Their results are attached to the
    search override kwargs only if the search tool is actually selected.

    Raises:
        ValueError: if the stream yields no message chunk, if the LLM
            requested tool calls but none match a known tool, or if the
            expanded-queries object is missing one of its expansion lists.
    """
    should_stream_answer = state.should_stream_answer

    # The per-run GraphConfig travels inside the runnable config's metadata.
    agent_config = cast(GraphConfig, config["metadata"]["config"])

    force_use_tool = agent_config.tooling.force_use_tool

    # Handles for optional background work; each stays None unless started below.
    embedding_thread: TimeoutThread[Embedding] | None = None
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None
    # If we have override_kwargs, add them to the tool_args
    override_kwargs: SearchToolOverrideKwargs = (
        force_use_tool.override_kwargs or SearchToolOverrideKwargs()
    )
    # NOTE: when force_use_tool.override_kwargs is set, this assignment
    # mutates that same object rather than a copy.
    override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query

    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
    # A snapshot on the state takes precedence over the config's prompt builder.
    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    # Pre-compute search inputs only for non-agentic runs that have a search
    # tool and are not forced onto some other tool.
    if (
        not agent_config.behavior.use_agentic_search
        and agent_config.tooling.search_tool is not None
        and (
            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
        )
    ):
        # Run in a background thread to avoid blocking the main thread
        embedding_thread = run_in_background(
            get_query_embedding,
            agent_config.inputs.prompt_builder.raw_user_query,
            agent_config.persistence.db_session,
        )
        keyword_thread = run_in_background(
            query_analysis,
            agent_config.inputs.prompt_builder.raw_user_query,
        )

        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:

            # Keyword and semantic rephrasings are requested in parallel.
            expanded_keyword_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.KEYWORD,
                prompt_builder,
            )
            expanded_semantic_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )

    structured_response_format = agent_config.inputs.structured_response_format
    # Only tools whose names are listed on the state are offered.
    tools = [
        tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
    ]

    tool, tool_args = None, None
    # Forced tool with explicit args: no LLM involvement is needed to pick it.
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    # NOTE: this is a truthiness test, so an empty args mapping falls through
    # to the streaming path below.
    if tool and tool_args:
        if embedding_thread and tool.name == SearchTool._NAME:
            # Wait for the embedding thread to finish
            embedding = wait_on_background(embedding_thread)
            override_kwargs.precomputed_query_embedding = embedding
        if keyword_thread and tool.name == SearchTool._NAME:
            is_keyword, keywords = wait_on_background(keyword_thread)
            override_kwargs.precomputed_is_keyword = is_keyword
            override_kwargs.precomputed_keywords = keywords
        # dual keyword expansion needs to be added here for non-tool calling LLM case
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and expanded_keyword_thread
            and expanded_semantic_thread
            and tool.name == SearchTool._NAME
        ):
            override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
                expanded_keyword_thread=expanded_keyword_thread,
                expanded_semantic_thread=expanded_semantic_thread,
            )
        # Sanity check: an expansions object must carry both lists.
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and tool.name == SearchTool._NAME
            and override_kwargs.expanded_queries
        ):
            if (
                override_kwargs.expanded_queries.keywords_expansions is None
                or override_kwargs.expanded_queries.semantic_expansions is None
            ):
                raise ValueError("No expanded keyword or semantic threads found.")

        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                # No LLM-issued call id exists on this path, so generate one.
                id=str(uuid4()),
                search_tool_override_kwargs=override_kwargs,
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # An AnswerPromptBuilder still has to be built; anything else is read
    # through its `built_prompt` attribute.
    built_prompt = (
        prompt_builder.build()
        if isinstance(prompt_builder, AnswerPromptBuilder)
        else prompt_builder.built_prompt
    )
    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=built_prompt,
        # An empty definition list is passed as None rather than [].
        tools=(
            [tool.tool_definition() for tool in tools] or None
            if using_tool_calling_llm
            else None
        ),
        tool_choice=(
            "required"
            if tools and force_use_tool.force_use and using_tool_calling_llm
            else None
        ),
        structured_response_format=structured_response_format,
    )

    # Consume the stream; the second argument controls whether answer tokens
    # are forwarded to `writer`.
    tool_message = process_llm_stream(
        stream,
        should_stream_answer
        and not agent_config.behavior.skip_gen_ai_answer_generation,
        writer,
        ind=0,
    ).ai_message_chunk

    if tool_message is None:
        raise ValueError("No tool message emitted by LLM")

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        logger.debug("No tool calls emitted by LLM")
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        # Reached only for requests naming a tool we do not know; logged once
        # per such request before moving on to the next one.
        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {tool_call_request}"
        )

    # The LLM asked for tool calls but none of them matched a known tool.
    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.debug(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    # Same background-result collection as the early-return path above, keyed
    # on the tool the LLM selected.
    if embedding_thread and selected_tool.name == SearchTool._NAME:
        # Wait for the embedding thread to finish
        embedding = wait_on_background(embedding_thread)
        override_kwargs.precomputed_query_embedding = embedding
    if keyword_thread and selected_tool.name == SearchTool._NAME:
        is_keyword, keywords = wait_on_background(keyword_thread)
        override_kwargs.precomputed_is_keyword = is_keyword
        override_kwargs.precomputed_keywords = keywords

    # The expansion threads only exist when the feature flag was on at start,
    # so the flag is not re-checked here.
    if (
        selected_tool.name == SearchTool._NAME
        and expanded_keyword_thread
        and expanded_semantic_thread
    ):

        override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
            expanded_keyword_thread=expanded_keyword_thread,
            expanded_semantic_thread=expanded_semantic_thread,
        )
    if (
        USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
        and selected_tool.name == SearchTool._NAME
        and override_kwargs.expanded_queries
    ):
        # TODO: this is a hack to handle the case where the expanded queries are not found.
        # We should refactor this to be more robust.
        if (
            override_kwargs.expanded_queries.keywords_expansions is None
            or override_kwargs.expanded_queries.semantic_expansions is None
        ):
            raise ValueError("No expanded keyword or semantic threads found.")

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            # Reuse the id the LLM assigned to its tool call.
            id=selected_tool_call_request["id"],
            search_tool_override_kwargs=override_kwargs,
        ),
    )
|
||||
@@ -1,17 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
|
||||
|
||||
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    """Build the initial `ToolChoiceInput` for the top-level tool-choice node.

    `state` is accepted but not read. Every tool configured on the graph
    config is offered by name; streaming is always enabled and no prompt
    snapshot is supplied.
    """
    agent_config = cast(GraphConfig, config["metadata"]["config"])

    configured_tools = agent_config.tooling.tools or []
    tool_names = [configured.name for configured in configured_tools]

    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        # None means the default prompt builder is used
        prompt_snapshot=None,
        tools=tool_names,
    )
|
||||
@@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from langchain_core.runnables.schema import CustomStreamEvent
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langfuse.langchain import CallbackHandler
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
|
||||
@@ -15,12 +16,13 @@ from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.configs.app_configs import LANGFUSE_PUBLIC_KEY
|
||||
from onyx.configs.app_configs import LANGFUSE_SECRET_KEY
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GraphInput = DCMainInput | KBMainInput | DRMainInput
|
||||
|
||||
|
||||
@@ -30,10 +32,16 @@ def manage_sync_streaming(
|
||||
graph_input: GraphInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
callbacks: list[CallbackHandler] = []
|
||||
if LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY:
|
||||
callbacks.append(CallbackHandler())
|
||||
for event in compiled_graph.stream(
|
||||
stream_mode="custom",
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
config={
|
||||
"metadata": {"config": config, "thread_id": str(message_id)},
|
||||
"callbacks": callbacks, # type: ignore
|
||||
},
|
||||
):
|
||||
yield cast(CustomStreamEvent, event)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Literal
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
@@ -27,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
||||
|
||||
|
||||
@traced(name="stream llm", type="llm")
|
||||
def stream_llm_answer(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
|
||||
@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.configs.constants import ONYX_DISCORD_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
{community_link_fragment}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -161,9 +161,9 @@ def build_html_email(
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
community_link_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
@@ -175,7 +175,7 @@ def build_html_email(
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
community_link_fragment=community_link_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
|
||||
177
backend/onyx/auth/jwt.py
Normal file
177
backend/onyx/auth/jwt.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
|
||||
|
||||
|
||||
class PublicKeyFormat(Enum):
    """Shape of the verification material fetched from the public-key URL."""

    # JSON document carrying a top-level "keys" entry (a JWK Set)
    JWKS = "jwks"
    # Any other non-empty response body, kept as raw text
    # (presumably a PEM-encoded key -- the content is not validated here)
    PEM = "pem"
|
||||
|
||||
|
||||
@lru_cache()
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
    """Fetch and cache the raw JWT verification material.

    Downloads `JWT_PUBLIC_KEY_URL` and classifies the body:

    * JSON (by Content-Type, or because the body starts with "{") that is a
      dict with a "keys" entry is returned as `(dict, PublicKeyFormat.JWKS)`.
    * Any other non-empty body is returned, stripped, as
      `(str, PublicKeyFormat.PEM)`.

    Returns None when the URL is not configured, the request fails or times
    out, the JSON is invalid or lacks "keys", or the body is empty.

    The result, including a None failure result, is memoised by `lru_cache`;
    callers that want a fresh download must call
    `_fetch_public_key_payload.cache_clear()` first.
    """
    if JWT_PUBLIC_KEY_URL is None:
        logger.error("JWT_PUBLIC_KEY_URL is not set")
        return None

    # requests applies no timeout by default, so without this an unresponsive
    # key server would block the calling request indefinitely. A timeout
    # raises requests.Timeout, which is a RequestException and is handled
    # below like any other fetch failure.
    fetch_timeout_seconds = 10

    try:
        response = requests.get(JWT_PUBLIC_KEY_URL, timeout=fetch_timeout_seconds)
        response.raise_for_status()
    except requests.RequestException as exc:
        logger.error(f"Failed to fetch JWT public key: {str(exc)}")
        return None

    content_type = response.headers.get("Content-Type", "").lower()
    raw_body = response.text
    body_lstripped = raw_body.lstrip()

    # Some servers serve JWKS with a generic content type, so also sniff the body.
    if "application/json" in content_type or body_lstripped.startswith("{"):
        try:
            data = response.json()
        except ValueError:
            logger.error("JWT public key URL returned invalid JSON")
            return None

        if isinstance(data, dict) and "keys" in data:
            return data, PublicKeyFormat.JWKS

        logger.error(
            "JWT public key URL returned JSON but no JWKS 'keys' field was found"
        )
        return None

    body = raw_body.strip()
    if not body:
        logger.error("JWT public key URL returned an empty response")
        return None

    return body, PublicKeyFormat.PEM
|
||||
|
||||
|
||||
def get_public_key(token: str) -> RSAPublicKey | str | None:
    """Return the concrete public key used to verify the provided JWT token.

    For JWKS material the key is selected using the token's header; for any
    other format the fetched text is returned unchanged. Returns None when
    no material could be fetched or no JWKS key could be resolved.
    """
    fetched = _fetch_public_key_payload()
    if fetched is None:
        logger.error("Failed to retrieve public key payload")
        return None

    material, material_format = fetched

    # Non-JWKS material is already the key text; no per-token lookup needed.
    if material_format is not PublicKeyFormat.JWKS:
        return cast(str, material)

    return _resolve_public_key_from_jwks(token, cast(dict[str, Any], material))
|
||||
|
||||
|
||||
def _resolve_public_key_from_jwks(
    token: str, jwks_payload: dict[str, Any]
) -> RSAPublicKey | None:
    """Pick the JWK matching `token`'s header and build an RSA key from it.

    Matching order: the header's "kid", then its "x5t" thumbprint, and
    finally, if the set holds exactly one key, that key regardless of header.
    If several keys match, the first is used.

    Returns None (after logging) when the header cannot be parsed, the set
    is empty, nothing matches, or the selected JWK cannot be turned into an
    RSA key.
    """
    try:
        # Only the unverified header is read here; it is used solely to
        # choose a key, and the signature is checked later by the caller.
        header = jwt.get_unverified_header(token)
    except PyJWTError as e:
        logger.error(f"Unable to parse JWT header: {str(e)}")
        return None

    keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
    if not keys:
        logger.error("JWKS payload did not contain any keys")
        return None

    kid = header.get("kid")
    thumbprint = header.get("x5t")

    candidates = []
    if kid:
        candidates = [k for k in keys if k.get("kid") == kid]
    if not candidates and thumbprint:
        candidates = [k for k in keys if k.get("x5t") == thumbprint]
    if not candidates and len(keys) == 1:
        candidates = keys

    if not candidates:
        logger.warning(
            "No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
        )
        return None

    if len(candidates) > 1:
        logger.warning(
            "Multiple JWKs matched token header kid=%s; selecting the first occurrence",
            kid,
        )

    jwk = candidates[0]
    try:
        return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
    except (ValueError, PyJWTError) as e:
        # PyJWT reports an unusable JWK (wrong "kty", missing fields, bad
        # JSON) as InvalidKeyError, which derives from PyJWTError rather than
        # ValueError. Catching only ValueError let those escape as unhandled
        # exceptions instead of a clean "no key" result.
        logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
        return None
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
    """Validate `token` and return the matching user, if any.

    The token is decoded as RS256 with audience verification disabled. On a
    key-resolution or decode failure the cached key material is cleared and
    the whole check is retried, up to `_PUBLIC_KEY_FETCH_ATTEMPTS` attempts
    in total, so that a rotated signing key is picked up.

    Returns the first user whose email equals the token's "email" claim
    (compared case-insensitively), or None if verification fails, the claim
    is missing, or no such user exists.
    """
    from sqlalchemy import func

    final_attempt = _PUBLIC_KEY_FETCH_ATTEMPTS - 1

    for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
        can_retry = attempt < final_attempt

        public_key = get_public_key(token)
        if public_key is None:
            logger.error("Unable to resolve a public key for JWT verification")
            if not can_retry:
                return None
            # Drop the cached material so the next attempt refetches it.
            _fetch_public_key_payload.cache_clear()
            continue

        try:
            payload = jwt_decode(
                token,
                public_key,
                algorithms=["RS256"],
                options={"verify_aud": False},
            )
        except InvalidTokenError as e:
            logger.error(f"Invalid JWT token: {str(e)}")
            if not can_retry:
                return None
            _fetch_public_key_payload.cache_clear()
            continue
        except PyJWTError as e:
            logger.error(f"JWT decoding error: {str(e)}")
            if not can_retry:
                return None
            _fetch_public_key_payload.cache_clear()
            continue

        email = payload.get("email")
        if not email:
            logger.warning(
                "JWT token decoded successfully but no email claim found; skipping auth"
            )
            # A valid token without an email will not improve on retry.
            break

        result = await async_db_session.execute(
            select(User).where(func.lower(User.email) == func.lower(email))
        )
        return result.scalars().first()

    return None
|
||||
@@ -3,12 +3,14 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import KV_NO_AUTH_USER_PERSONALIZATION_KEY
|
||||
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from onyx.configs.constants import NO_AUTH_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.key_value_store.store import KeyValueStore
|
||||
from onyx.key_value_store.store import KvKeyNotFoundError
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.manage.models import UserPersonalization
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
@@ -18,6 +20,12 @@ def set_no_auth_user_preferences(
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def set_no_auth_user_personalization(
    store: KeyValueStore, personalization: UserPersonalization
) -> None:
    """Save the no-auth user's personalization settings in the key-value store.

    The model is dumped to a plain mapping and written under
    `KV_NO_AUTH_USER_PERSONALIZATION_KEY`.
    """
    serialized = personalization.model_dump()
    store.store(KV_NO_AUTH_USER_PERSONALIZATION_KEY, serialized)
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
@@ -33,6 +41,15 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
personalization = UserPersonalization()
|
||||
try:
|
||||
personalization_data = cast(
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PERSONALIZATION_KEY)
|
||||
)
|
||||
personalization = UserPersonalization(**personalization_data)
|
||||
except KvKeyNotFoundError:
|
||||
pass
|
||||
|
||||
return UserInfo(
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
@@ -41,6 +58,7 @@ def fetch_no_auth_user(
|
||||
is_verified=True,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
personalization=personalization,
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
password_configured=False,
|
||||
)
|
||||
|
||||
@@ -54,6 +54,8 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
@@ -61,6 +63,7 @@ from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.jwt import verify_jwt_token
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -70,6 +73,7 @@ from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from onyx.configs.app_configs import PASSWORD_MAX_LENGTH
|
||||
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
|
||||
from onyx.configs.app_configs import PASSWORD_REQUIRE_DIGIT
|
||||
@@ -103,19 +107,21 @@ from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.saml import get_saml_account
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.secrets import extract_hashed_cookie
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -134,10 +140,9 @@ def is_user_admin(user: User | None) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise ValueError(
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
@@ -324,8 +329,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
user = await super().create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user_created = True
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -351,11 +360,42 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def _assign_default_pinned_assistants(
    self, user: User, db_session: AsyncSession
) -> None:
    """Pin the default assistants for a user who has no pinned list yet.

    Does nothing when `user.pinned_assistants` is already set (even to an
    empty list) or when no persona qualifies. Qualifying personas are
    default, public, visible and not deleted, ordered by display priority
    (NULLs last) and then by id.
    """
    # Any non-None value, including an empty list, is left untouched.
    if user.pinned_assistants is not None:
        return

    eligibility = (
        Persona.is_default_persona.is_(True),
        Persona.is_public.is_(True),
        Persona.is_visible.is_(True),
        Persona.deleted.is_(False),
    )
    ordering = (
        nulls_last(Persona.display_priority.asc()),
        Persona.id.asc(),
    )
    query = select(Persona.id).where(*eligibility).order_by(*ordering)

    rows = await db_session.execute(query)
    default_persona_ids = list(rows.scalars().all())
    if not default_persona_ids:
        return

    await self.user_db.update(user, {"pinned_assistants": default_persona_ids})
    # Set the attribute on the in-memory object as well.
    user.pinned_assistants = default_persona_ids
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -476,6 +516,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -1018,13 +1059,28 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
async def optional_user_(
|
||||
async def _check_for_saml_and_jwt(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
async_db_session: AsyncSession,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
# Check if the user has a session cookie from SAML
|
||||
if AUTH_TYPE == AuthType.SAML:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = await get_saml_account(
|
||||
cookie=saved_cookie, async_db_session=async_db_session
|
||||
)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
# If user is still None, check for JWT in Authorization header
|
||||
if user is None and JWT_PUBLIC_KEY_URL is not None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer ") :].strip()
|
||||
user = await verify_jwt_token(token, async_db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -1033,14 +1089,14 @@ async def optional_user(
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"onyx.auth.users", "optional_user_"
|
||||
)
|
||||
user = await versioned_fetch_user(request, user, async_db_session)
|
||||
user = await _check_for_saml_and_jwt(request, user, async_db_session)
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
try:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
except ValueError:
|
||||
hashed_api_key = None
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
|
||||
137
backend/onyx/background/celery/apps/background.py
Normal file
137
backend/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's `task_prerun` signal for the background worker app.

    All arguments are forwarded unchanged to the shared handler
    `app_base.on_task_prerun`.
    """
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_postrun`` signal for this app.

    All signal arguments, including the task's return value and final state,
    are forwarded unchanged to the shared ``app_base.on_task_postrun`` handler.
    """
    app_base.on_task_postrun(
        sender,
        task_id,
        task,
        args,
        kwargs,
        retval,
        state,
        **kwds,
    )
|
||||
|
||||
|
||||
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    """Handle the ``celeryd_init`` signal by delegating to ``app_base.on_celeryd_init``."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Prepare the consolidated background worker when ``worker_init`` fires.

    Steps, in order:
      1. Name and initialise the SQL engine, sized to the worker concurrency
         with a small overflow allowance.
      2. Initialise the Vespa httpx pool (with client cert/key when
         ``MANAGED_VESPA`` is set).
      3. Wait for Redis, the database and Vespa via the shared helpers.
      4. Outside multi-tenant mode, run the shared secondary-worker init.
    """
    EXTRA_CONCURRENCY = 8  # small extra fudge factor for connection limits

    logger.info("worker_init signal received for consolidated background worker.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)

    # Initialize Vespa httpx pool (needed for light worker tasks)
    vespa_pool_size = sender.concurrency + EXTRA_CONCURRENCY  # type: ignore
    if not MANAGED_VESPA:
        httpx_init_vespa_pool(vespa_pool_size)
    else:
        httpx_init_vespa_pool(
            vespa_pool_size,
            ssl_cert=VESPA_CLOUD_CERT_PATH,
            ssl_key=VESPA_CLOUD_KEY_PATH,
        )

    # Block until each backing service is reachable.
    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Fewer startup checks in the multi-tenant case.
    if not MULTI_TENANT:
        app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Handle the ``worker_ready`` signal by delegating to ``app_base.on_worker_ready``."""
    app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Handle the ``worker_shutdown`` signal by delegating to ``app_base.on_worker_shutdown``."""
    app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
    """Reset the SQL engine when ``worker_process_init`` fires.

    NOTE(review): presumably this keeps a new worker process from reusing
    engine state created before it started -- confirm against ``SqlEngine``.
    """
    SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any,
    logfile: Any,
    format: Any,
    colorize: Any,
    **kwargs: Any,
) -> None:
    """Handle Celery's ``setup_logging`` signal via ``app_base.on_setup_logging``.

    The parameter names (including ``format``) are kept as-is because they are
    part of this handler's signature.
    """
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
# Register every shared bootstep on this app's worker blueprint.
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

# Task packages served by the consolidated background worker, grouped by the
# worker type each group is labelled with.
celery_app.autodiscover_tasks(
    [
        # Original background worker tasks
        "onyx.background.celery.tasks.pruning",
        "onyx.background.celery.tasks.kg_processing",
        "onyx.background.celery.tasks.monitoring",
        "onyx.background.celery.tasks.user_file_processing",
        # Light worker tasks
        "onyx.background.celery.tasks.shared",
        "onyx.background.celery.tasks.vespa",
        "onyx.background.celery.tasks.connector_deletion",
        "onyx.background.celery.tasks.doc_permission_syncing",
        # Docprocessing worker tasks
        "onyx.background.celery.tasks.docprocessing",
        # Docfetching worker tasks
        "onyx.background.celery.tasks.docfetching",
    ]
)
|
||||
@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Ensure the user files indexing worker registers the doc_id migration task
|
||||
# TODO(subash): remove this once the doc_id migration is complete
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -324,5 +324,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
23
backend/onyx/background/celery/configs/background.py
Normal file
23
backend/onyx/background/celery/configs/background.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
|
||||
# --- Broker: taken as-is from the shared base config -------------------------
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

# --- Redis connection behaviour: taken as-is from the shared base config -----
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

# --- Result backend ----------------------------------------------------------
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

# --- Task defaults -----------------------------------------------------------
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# --- Worker settings specific to the consolidated background worker ----------
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Prefetch 4 tasks per thread (raised from 1) so that fast light-worker tasks
# are handled more efficiently.
worker_prefetch_multiplier = 4
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -33,17 +33,34 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-project-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-delete",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -42,6 +42,12 @@ from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
delete_doc_permission_sync_attempts__no_commit,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
delete_external_group_permission_sync_attempts__no_commit,
|
||||
)
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -441,6 +447,16 @@ def monitor_connector_deletion_taskset(
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# permission sync attempts
|
||||
delete_doc_permission_sync_attempts__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
delete_external_group_permission_sync_attempts__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -895,6 +895,9 @@ def monitor_celery_queues_helper(
|
||||
n_user_file_project_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery
|
||||
)
|
||||
n_user_file_delete = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_DELETE, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
|
||||
@@ -924,6 +927,7 @@ def monitor_celery_queues_helper(
|
||||
f"user_files_indexing={n_user_files_indexing} "
|
||||
f"user_file_processing={n_user_file_processing} "
|
||||
f"user_file_project_sync={n_user_file_project_sync} "
|
||||
f"user_file_delete={n_user_file_delete} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
|
||||
@@ -9,17 +9,19 @@ import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -36,15 +38,15 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
@@ -68,6 +70,56 @@ def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
    """Return the Redis lock key used while deleting the given user file."""
    prefix = OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX
    return f"{prefix}:{user_file_id}"
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
    *,
    http_client: httpx.Client,
    index_name: str,
    selection: str,
    continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
    """Fetch one page of documents matching ``selection`` from the index.

    Issues a GET against ``DOCUMENT_ID_ENDPOINT`` for ``index_name``. The
    ``@retry`` decorator re-invokes the call on failure (up to 3 tries with
    backoff and jitter), so an HTTP error status raised here is retried.

    Args:
        http_client: Client used to issue the request.
        index_name: Index name substituted into the endpoint URL.
        selection: Selection expression sent as the ``selection`` parameter.
        continuation: Token from a previous page, if any.

    Returns:
        ``(documents, continuation)`` from the JSON response: the ``documents``
        list (empty when the key is absent) and the ``continuation`` token
        (``None`` when the key is absent).
    """
    task_logger.info(
        f"Visiting chunks for index={index_name} with selection={selection}"
    )

    query: dict[str, str] = {
        "selection": selection,
        "wantedDocumentCount": "100",  # Use smaller batch size to avoid timeouts
    }
    if continuation:
        query["continuation"] = continuation

    endpoint = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
    # timeout=None: no client-side timeout is applied to this request.
    response = http_client.get(endpoint, params=query, timeout=None)
    response.raise_for_status()

    body = response.json()
    return body.get("documents", []), body.get("continuation")
|
||||
|
||||
|
||||
def _get_document_chunk_count(
    *,
    index_name: str,
    selection: str,
) -> int:
    """Count the documents matching ``selection`` by paging through ``_visit_chunks``.

    Paging stops as soon as a page comes back empty or no continuation token
    is returned.

    Args:
        index_name: Index to visit.
        selection: Selection expression passed through to ``_visit_chunks``.

    Returns:
        Total number of documents seen across all pages.
    """
    total = 0
    next_token = None
    more_pages = True
    while more_pages:
        page, next_token = _visit_chunks(
            http_client=HttpxPool.get("vespa"),
            index_name=index_name,
            selection=selection,
            continuation=next_token,
        )
        # An empty page adds nothing; it and a missing token both end the walk.
        total += len(page)
        more_pages = bool(page) and bool(next_token)
    return total
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
|
||||
soft_time_limit=300,
|
||||
@@ -134,7 +186,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
@@ -244,18 +297,31 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
# don't update the status if the user file is being deleted
|
||||
# Re-fetch to avoid mypy error
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
and current_user_file.status != UserFileStatus.DELETING
|
||||
):
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
and current_user_file.status != UserFileStatus.DELETING
|
||||
):
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
@@ -268,7 +334,9 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
@@ -281,6 +349,149 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
    soft_time_limit=300,
    bind=True,
    ignore_result=True,
)
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
    """Find user files in ``DELETING`` status and enqueue one delete task each.

    A non-blocking Redis beat lock guards the scan; when the lock is already
    held the task returns immediately. Any error during the scan or enqueue is
    logged and swallowed, and the lock is released if still owned.
    """
    task_logger.info("check_for_user_file_delete - Starting")

    beat_lock: RedisLock = get_redis_client(tenant_id=tenant_id).lock(
        OnyxRedisLocks.USER_FILE_DELETE_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )
    if not beat_lock.acquire(blocking=False):
        return None

    enqueued = 0
    try:
        with get_session_with_current_tenant() as db_session:
            deleting_stmt = select(UserFile.id).where(
                UserFile.status == UserFileStatus.DELETING
            )
            pending_ids = db_session.execute(deleting_stmt).scalars().all()

            for pending_id in pending_ids:
                task_kwargs = {
                    "user_file_id": str(pending_id),
                    "tenant_id": tenant_id,
                }
                self.app.send_task(
                    OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
                    kwargs=task_kwargs,
                    queue=OnyxCeleryQueues.USER_FILE_DELETE,
                    priority=OnyxCeleryPriority.HIGH,
                )
                enqueued += 1
    except Exception as e:
        task_logger.exception(
            f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
        )
        return None
    finally:
        # The lock may have timed out meanwhile; only release it if still ours.
        if beat_lock.owned():
            beat_lock.release()

    task_logger.info(
        f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
    )
    return None
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
    bind=True,
    ignore_result=True,
)
def process_single_user_file_delete(
    self: Task, *, user_file_id: str, tenant_id: str
) -> None:
    """Delete one user file: its index chunks, its stored blobs, then its DB row.

    Flow:
      * Take a non-blocking per-file Redis lock; if it is already held, log
        and return.
      * Look up the ``UserFile`` row; if it no longer exists, log and return.
      * Delete the document's chunks from the document index. The stored
        ``chunk_count`` is used when set; otherwise the chunks are counted by
        visiting the index.
      * Delete the uploaded file and its plaintext cache from the file store.
        Each deletion is best-effort and independent, so a failure on one is
        logged and does not prevent the other from being attempted.
      * Delete the ``UserFile`` row and commit.

    Any other exception is logged and swallowed. The lock is released on exit
    if this task still owns it.

    Args:
        user_file_id: String form of the ``UserFile`` id to delete.
        tenant_id: Tenant whose Redis client is used for locking; also passed
            to the index deletion.
    """
    task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")

    redis_client = get_redis_client(tenant_id=tenant_id)
    file_lock: RedisLock = redis_client.lock(
        _user_file_delete_lock_key(user_file_id),
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )
    if not file_lock.acquire(blocking=False):
        task_logger.info(
            f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
        )
        return None

    try:
        with get_session_with_current_tenant() as db_session:
            # 20 is the documented default for httpx max_keepalive_connections
            if MANAGED_VESPA:
                httpx_init_vespa_pool(
                    20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
                )
            else:
                httpx_init_vespa_pool(20)

            active_search_settings = get_active_search_settings(db_session)
            document_index = get_default_document_index(
                search_settings=active_search_settings.primary,
                secondary_search_settings=active_search_settings.secondary,
                httpx_client=HttpxPool.get("vespa"),
            )
            retry_index = RetryDocumentIndex(document_index)
            index_name = active_search_settings.primary.index_name
            selection = f"{index_name}.document_id=='{user_file_id}'"

            user_file = db_session.get(UserFile, _as_uuid(user_file_id))
            if not user_file:
                task_logger.info(
                    f"process_single_user_file_delete - User file not found id={user_file_id}"
                )
                return None

            # 1) Delete Vespa chunks for the document
            if user_file.chunk_count is None or user_file.chunk_count == 0:
                # No usable stored count: count the chunks in the index instead.
                chunk_count = _get_document_chunk_count(
                    index_name=index_name,
                    selection=selection,
                )
            else:
                chunk_count = user_file.chunk_count

            retry_index.delete_single(
                doc_id=user_file_id,
                tenant_id=tenant_id,
                chunk_count=chunk_count,
            )

            # 2) Delete the user-uploaded file content from filestore (blob + metadata)
            # Each deletion has its own try block so that failing to delete the
            # uploaded file cannot skip the plaintext cache; the row is removed
            # in step 3, so a skipped blob could not be found again later.
            file_store = get_default_file_store()
            try:
                file_store.delete_file(user_file.file_id)
            except Exception as e:
                # Best-effort: any failure (including a missing file) is logged
                # and the task carries on.
                task_logger.exception(
                    f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
                )
            try:
                file_store.delete_file(
                    user_file_id_to_plaintext_file_name(user_file.id)
                )
            except Exception as e:
                # Best-effort, as above.
                task_logger.exception(
                    f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
                )

            # 3) Finally, delete the UserFile row
            db_session.delete(user_file)
            db_session.commit()
            task_logger.info(
                f"process_single_user_file_delete - Completed id={user_file_id}"
            )
    except Exception as e:
        task_logger.exception(
            f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
        )
        return None
    finally:
        # The lock may have timed out meanwhile; only release it if still ours.
        if file_lock.owned():
            file_lock.release()
    return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
|
||||
soft_time_limit=300,
|
||||
@@ -306,8 +517,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
UserFile.needs_project_sync.is_(True)
|
||||
and UserFile.status == UserFileStatus.COMPLETED
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
@@ -348,7 +561,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
@@ -359,6 +572,15 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
@@ -416,147 +638,155 @@ def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
return old_id
|
||||
|
||||
|
||||
def _visit_chunks(
|
||||
*,
|
||||
http_client: httpx.Client,
|
||||
index_name: str,
|
||||
selection: str,
|
||||
continuation: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
params: dict[str, str] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": "1000",
|
||||
}
|
||||
if continuation:
|
||||
params["continuation"] = continuation
|
||||
resp = http_client.get(base_url, params=params, timeout=None)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
return payload.get("documents", []), payload.get("continuation")
|
||||
def update_legacy_plaintext_file_records() -> None:
|
||||
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
|
||||
keys. Copies each S3 object to its expected UUID key and updates DB.
|
||||
|
||||
Examples:
|
||||
- Old key: bucket/schema/plaintext_<int>
|
||||
- New key: bucket/schema/plaintext_<uuid>
|
||||
"""
|
||||
|
||||
def _update_document_id_in_vespa(
|
||||
*,
|
||||
index_name: str,
|
||||
old_doc_id: str,
|
||||
new_doc_id: str,
|
||||
user_project_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
|
||||
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
|
||||
task_logger.info("update_legacy_plaintext_file_records - Starting")
|
||||
|
||||
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
|
||||
task_logger.debug(f"Vespa selection: {selection}")
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
store = get_default_file_store()
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
continuation: str | None = None
|
||||
while True:
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=http_client,
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
if not isinstance(store, S3BackedFileStore):
|
||||
task_logger.info(
|
||||
"update_legacy_plaintext_file_records - Skipping non-S3 store"
|
||||
)
|
||||
if not docs:
|
||||
break
|
||||
for doc in docs:
|
||||
vespa_full_id = doc.get("id")
|
||||
if not vespa_full_id:
|
||||
return
|
||||
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
|
||||
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
vespa_doc_uuid = vespa_full_id.split("::")[-1]
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
|
||||
update_request: dict[str, Any] = {
|
||||
"fields": {"document_id": {"assign": clean_new_doc_id}}
|
||||
}
|
||||
if user_project_ids is not None:
|
||||
update_request["fields"][USER_PROJECT] = {
|
||||
"assign": user_project_ids
|
||||
}
|
||||
r = http_client.put(vespa_url, json=update_request)
|
||||
r.raise_for_status()
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
if fr.bucket_name is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
|
||||
continue
|
||||
|
||||
if fr.object_key is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Object key is None")
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
|
||||
|
||||
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
|
||||
- Update `search_doc.document_id` to the same UUID string.
|
||||
"""
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
|
||||
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
updated_count = 0
|
||||
try:
|
||||
update_legacy_plaintext_file_records()
|
||||
# Track lock renewal
|
||||
last_lock_time = time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
active_settings.primary,
|
||||
active_settings.secondary,
|
||||
search_settings=active_settings.primary,
|
||||
secondary_search_settings=active_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
index_name = "danswer_index"
|
||||
|
||||
# Fetch mappings of legacy -> new ids
|
||||
rows = db_session.execute(
|
||||
sa.select(
|
||||
UserFile.document_id.label("document_id"),
|
||||
UserFile.id.label("id"),
|
||||
).where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
).all()
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
|
||||
# dedupe by old document_id
|
||||
seen: set[str] = set()
|
||||
for row in rows:
|
||||
old_doc_id = str(row.document_id)
|
||||
new_uuid = str(row.id)
|
||||
if not old_doc_id or not new_uuid or old_doc_id in seen:
|
||||
continue
|
||||
seen.add(old_doc_id)
|
||||
# collect user project ids for a combined Vespa update
|
||||
user_project_ids: list[int] | None = None
|
||||
try:
|
||||
uf = db_session.get(UserFile, UUID(new_uuid))
|
||||
if uf is not None:
|
||||
user_project_ids = [project.id for project in uf.projects]
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
try:
|
||||
_update_document_id_in_vespa(
|
||||
index_name=index_name,
|
||||
old_doc_id=old_doc_id,
|
||||
new_doc_id=new_uuid,
|
||||
user_project_ids=user_project_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
# Update search_doc records to refer to the UUID string
|
||||
# we are not using document_id_migrated = false because if the migration already completed,
|
||||
# it will not run again and we will not update the search_doc records because of the issue currently fixed
|
||||
# Select user files with a legacy doc id that have not been migrated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(UserFile.document_id.is_not(None))
|
||||
sa.select(UserFile).where(
|
||||
sa.and_(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
@@ -567,9 +797,9 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(f"Found {len(user_files)} user files to update")
|
||||
task_logger.info(f"Found {len(search_docs)} search docs to update")
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
|
||||
)
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
@@ -580,120 +810,128 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"Built search doc map with {len(search_doc_map)} entries"
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
# Process each UserFile and update matching SearchDocs
|
||||
updated_count = 0
|
||||
for uf in user_files:
|
||||
doc_id = uf.document_id
|
||||
if doc_id.startswith("USER_FILE_CONNECTOR__"):
|
||||
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
|
||||
task_logger.debug(
|
||||
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
|
||||
)
|
||||
|
||||
if doc_id in search_doc_map:
|
||||
search_docs = search_doc_map[doc_id]
|
||||
task_logger.debug(
|
||||
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
# Update the SearchDoc to use the UserFile's UUID
|
||||
for search_doc in search_docs:
|
||||
search_doc.document_id = str(uf.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
# Mark UserFile as migrated
|
||||
uf.document_id_migrated = True
|
||||
db_session.add(uf)
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
updated_chunks = retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = updated_chunks
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Updated {updated_count} SearchDoc records with new UUIDs"
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
|
||||
try:
|
||||
store = get_default_file_store()
|
||||
# Only supported for S3-backed stores where we can manipulate object keys
|
||||
if isinstance(store, S3BackedFileStore):
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.file_id.like("plaintext_%"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(
|
||||
Bucket=fr.bucket_name, Key=fr.object_key
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed plaintext object normalize for "
|
||||
f"id={fr.file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"onyx.background.celery.apps.background",
|
||||
"celery_app",
|
||||
)
|
||||
29
backend/onyx/chat/memories.py
Normal file
29
backend/onyx/chat/memories.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def make_memories_callback(
    user: User | None, db_session: Session
) -> Callable[[], list[str]]:
    """Build a zero-argument callback listing what is known about ``user``.

    Nothing is read from ``user`` or the database until the returned callback
    is actually invoked.

    Args:
        user: The user to describe, or ``None`` when there is no user.
        db_session: Session used to load the user's ``Memory`` rows.

    Returns:
        A callable that returns a list of strings: one line each for the
        user's name, role and email (only for the values that are set),
        followed by every non-empty stored memory text. The list is empty
        when ``user`` is ``None``.
    """

    def memories_callback() -> list[str]:
        if user is None:
            return []

        # Emit a profile line only when the underlying value is set, so the
        # result never contains empty placeholder strings.
        profile_fields = (
            ("User's name", user.personal_name),
            ("User's role", user.personal_role),
            ("User's email", user.email),
        )
        user_info = [f"{label}: {value}" for label, value in profile_fields if value]

        memory_rows = db_session.scalars(
            select(Memory).where(Memory.user_id == user.id)
        ).all()
        memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
        return user_info + memories

    return memories_callback
|
||||
@@ -2,18 +2,23 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from agents import Model
|
||||
from agents import ModelSettings
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.agents.agent_sdk.message_format import base_messages_to_agent_sdk_msgs
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
from onyx.chat.chat_utils import process_kg_commands
|
||||
from onyx.chat.memories import make_memories_callback
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
@@ -26,12 +31,15 @@ from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import (
|
||||
process_streamed_packets,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import (
|
||||
default_build_system_message_v2,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.turn import fast_chat_turn
|
||||
from onyx.chat.turn.infra.emitter import get_default_emitter
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.chat.user_files.parse_user_files import parse_user_files
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -70,18 +78,22 @@ from onyx.db.projects import get_project_instructions
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.feature_flags.factory import get_default_feature_flag_provider
|
||||
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llm_model_and_settings_for_persona
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -89,6 +101,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
@@ -108,11 +121,14 @@ from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
class PartialResponse(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
@@ -137,12 +153,12 @@ def _build_project_llm_docs(
|
||||
return project_llm_docs
|
||||
|
||||
project_file_id_set = set(project_file_ids)
|
||||
|
||||
def _strip_nuls(s: str) -> str:
|
||||
return s.replace("\x00", "") if s else s
|
||||
|
||||
for f in in_memory_user_files:
|
||||
if project_file_id_set and (f.file_id in project_file_id_set):
|
||||
|
||||
def _strip_nuls(s: str) -> str:
|
||||
return s.replace("\x00", "") if s else s
|
||||
|
||||
cleaned_filename = _strip_nuls(f.filename or str(f.file_id))
|
||||
|
||||
if f.file_type.is_text_file():
|
||||
@@ -363,14 +379,12 @@ def stream_chat_message_objects(
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
persona = _get_persona_for_chat_session(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
default_persona=chat_session.persona,
|
||||
)
|
||||
|
||||
# TODO: remove once we have an endpoint for this stuff
|
||||
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
|
||||
|
||||
@@ -740,15 +754,29 @@ def stream_chat_message_objects(
|
||||
and (file.file_id not in project_file_ids)
|
||||
]
|
||||
)
|
||||
|
||||
feature_flag_provider = get_default_feature_flag_provider()
|
||||
simple_agent_framework_enabled = (
|
||||
feature_flag_provider.feature_enabled_for_user_tenant(
|
||||
flag_key=SIMPLE_AGENT_FRAMEWORK,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
and not new_msg_req.use_agentic_search
|
||||
)
|
||||
prompt_user_message = default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
prompt_config=prompt_config,
|
||||
files=latest_query_files,
|
||||
)
|
||||
mem_callback = make_memories_callback(user, db_session)
|
||||
system_message = (
|
||||
default_build_system_message_v2(prompt_config, llm.config, mem_callback)
|
||||
if simple_agent_framework_enabled
|
||||
else default_build_system_message(prompt_config, llm.config, mem_callback)
|
||||
)
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
prompt_config=prompt_config,
|
||||
files=latest_query_files,
|
||||
single_message_history=single_message_history,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
user_message=prompt_user_message,
|
||||
system_message=system_message,
|
||||
message_history=message_history,
|
||||
llm_config=llm.config,
|
||||
raw_user_query=final_msg.message,
|
||||
@@ -790,11 +818,29 @@ def stream_chat_message_objects(
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
project_instructions=project_instructions,
|
||||
)
|
||||
if simple_agent_framework_enabled:
|
||||
llm_model, model_settings = get_llm_model_and_settings_for_persona(
|
||||
persona=persona,
|
||||
llm_override=(new_msg_req.llm_override or chat_session.llm_override),
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
yield from _fast_message_stream(
|
||||
answer,
|
||||
tools,
|
||||
db_session,
|
||||
get_redis_client(),
|
||||
chat_session_id,
|
||||
reserved_message_id,
|
||||
prompt_config,
|
||||
llm_model,
|
||||
model_settings,
|
||||
)
|
||||
else:
|
||||
from onyx.chat.packet_proccessing import process_streamed_packets
|
||||
|
||||
# Process streamed packets using the new packet processing module
|
||||
yield from process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
yield from process_streamed_packets.process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
@@ -831,6 +877,58 @@ def stream_chat_message_objects(
|
||||
return
|
||||
|
||||
|
||||
# TODO: Refactor this to live somewhere else
|
||||
def _fast_message_stream(
    answer: Answer,
    tools: list[Tool],
    db_session: Session,
    redis_client: Redis,
    chat_session_id: UUID,
    reserved_message_id: int,
    prompt_config: PromptConfig,
    llm_model: Model,
    model_settings: ModelSettings,
) -> Generator[Packet, None, None]:
    """Hand the prepared answer over to ``fast_chat_turn`` and return its stream.

    Args:
        answer: Prepared answer object; its prompt builder supplies the
            messages and its graph tooling/config supply the primary LLM,
            search tool and research type.
        tools: Tools available for this turn. They are converted with
            ``tools_to_function_tools``; the image generation and Okta
            profile tools are additionally passed on their own.
        db_session: Database session forwarded to the turn dependencies.
        redis_client: Redis client forwarded to the turn dependencies.
        chat_session_id: ID of the chat session being answered.
        reserved_message_id: ID of the message reserved for this answer.
        prompt_config: Prompt configuration forwarded to the turn.
        llm_model: Agent SDK model to run the turn with.
        model_settings: Agent SDK model settings to run the turn with.

    Returns:
        Whatever ``fast_chat_turn.fast_chat_turn`` returns for this turn.
    """
    # Imported here rather than at module level, as in the original.
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    )
    from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
        OktaProfileTool,
    )

    # If several tools of the same kind are present, the last one wins.
    image_tool = None
    okta_tool = None
    for candidate in tools:
        if isinstance(candidate, ImageGenerationTool):
            image_tool = candidate
        elif isinstance(candidate, OktaProfileTool):
            okta_tool = candidate

    built_prompt = answer.graph_inputs.prompt_builder.build()
    agent_messages = base_messages_to_agent_sdk_msgs(built_prompt)
    turn_emitter = get_default_emitter()

    # TODO: Maybe we can use some DI framework here?
    turn_dependencies = ChatTurnDependencies(
        llm_model=llm_model,
        model_settings=model_settings,
        llm=answer.graph_tooling.primary_llm,
        tools=tools_to_function_tools(tools),
        search_pipeline=answer.graph_tooling.search_tool,
        image_generation_tool=image_tool,
        okta_profile_tool=okta_tool,
        db_session=db_session,
        redis_client=redis_client,
        emitter=turn_emitter,
    )
    return fast_chat_turn.fast_chat_turn(
        messages=agent_messages,
        dependencies=turn_dependencies,
        chat_session_id=chat_session_id,
        message_id=reserved_message_id,
        research_type=answer.graph_config.behavior.research_type,
        prompt_config=prompt_config,
    )
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
|
||||
@@ -23,6 +23,8 @@ from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_memories
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
@@ -31,9 +33,10 @@ from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
def default_build_system_message_v2(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
memories_callback: Callable[[], list[str]] | None = None,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
|
||||
@@ -52,6 +55,41 @@ def default_build_system_message(
|
||||
if not tag_handled_prompt:
|
||||
return None
|
||||
|
||||
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
|
||||
|
||||
if memories_callback:
|
||||
tag_handled_prompt = handle_memories(tag_handled_prompt, memories_callback)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
|
||||
|
||||
def default_build_system_message(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    memories_callback: Callable[[], list[str]] | None = None,
) -> SystemMessage | None:
    """Assemble the system message for a chat prompt.

    Args:
        prompt_config: Supplies the raw system prompt and the
            ``datetime_aware`` flag.
        llm_config: Provider and model name of the target LLM.
        memories_callback: Optional callback; when given, it is passed to
            ``handle_memories`` together with the prompt.

    Returns:
        The finished ``SystemMessage``, or ``None`` when the prompt is empty
        after date-awareness handling.
    """
    base_prompt = prompt_config.system_prompt.strip()

    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    targets_openai_o_model = (
        llm_config.model_provider == OPENAI_PROVIDER_NAME
        and llm_config.model_name.startswith("o")
    )
    if targets_openai_o_model:
        base_prompt = CODE_BLOCK_MARKDOWN + base_prompt

    prompt = handle_onyx_date_awareness(
        base_prompt,
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )
    if not prompt:
        return None

    prompt = handle_company_awareness(prompt)
    if memories_callback:
        prompt = handle_memories(prompt, memories_callback)

    return SystemMessage(content=prompt)
|
||||
|
||||
|
||||
|
||||
56
backend/onyx/chat/stop_signal_checker.py
Normal file
56
backend/onyx/chat/stop_signal_checker.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
    """
    Raise or lower the stop-signal fence of a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
        value: True writes the fence key (stop signal), False deletes it
    """
    key = f"{FENCE_PREFIX}_{get_current_tenant_id()}_{chat_session_id}"
    if value:
        # The stored 0 is a placeholder; is_connected only tests key existence.
        redis_client.set(key, 0)
    else:
        redis_client.delete(key)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
    """
    Report whether a chat session has NOT been asked to stop.

    Args:
        chat_session_id: The UUID of the chat session to check
        redis_client: Redis client used to look up the stop-signal fence

    Returns:
        False when the fence key for this tenant and session exists,
        True otherwise
    """
    key = f"{FENCE_PREFIX}_{get_current_tenant_id()}_{chat_session_id}"
    stop_requested = bool(redis_client.exists(key))
    return not stop_requested
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
    """
    Remove any stop signal recorded for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
    """
    current_tenant = get_current_tenant_id()
    redis_client.delete(f"{FENCE_PREFIX}_{current_tenant}_{chat_session_id}")
|
||||
1
backend/onyx/chat/turn/__init__.py
Normal file
1
backend/onyx/chat/turn/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Turn module for chat functionality
|
||||
333
backend/onyx/chat/turn/fast_chat_turn.py
Normal file
333
backend/onyx/chat/turn/fast_chat_turn.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from agents import Agent
|
||||
from agents import RawResponsesStreamEvent
|
||||
from agents import RunResultStreaming
|
||||
from agents import ToolCallItem
|
||||
from agents.tracing import trace
|
||||
|
||||
from onyx.agents.agent_sdk.sync_agent_stream_adapter import SyncAgentStream
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
|
||||
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
|
||||
from onyx.chat.turn.infra.session_sink import save_iteration
|
||||
from onyx.chat.turn.models import AgentToolType
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.prompt_utils import build_task_prompt_reminders_v2
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketObj
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ResponseFunctionToolCall
|
||||
|
||||
|
||||
def _remove_last_task_prompt_and_insert_new_one(
    chat_turn_user_message: str,
    current_messages: list[dict],
    prompt_config: PromptConfig,
    ctx: ChatTurnContext,
) -> list[dict]:
    """Replace the most recent user message with a freshly built task prompt.

    Args:
        chat_turn_user_message: The user's message for this chat turn, used
            to build the new task prompt.
        current_messages: Conversation so far. The last entry whose ``role``
            is ``"user"`` is removed from this list IN PLACE.
        prompt_config: Prompt configuration passed to the prompt builder.
        ctx: Chat turn context; ``should_cite_documents`` decides whether
            the new prompt asks for citations.

    Returns:
        A new list: the remaining messages followed by one user message
        carrying the new task prompt.
    """
    refreshed_prompt = build_task_prompt_reminders_v2(
        chat_turn_user_message,
        prompt_config,
        use_language_hint=False,
        should_cite=ctx.should_cite_documents,
    )

    # Search from the end so only the latest user message is dropped.
    last_user_index = next(
        (
            index
            for index in reversed(range(len(current_messages)))
            if current_messages[index].get("role") == "user"
        ),
        None,
    )
    if last_user_index is not None:
        del current_messages[last_user_index]

    return [*current_messages, {"role": "user", "content": refreshed_prompt}]
|
||||
|
||||
|
||||
# TODO -- this can be refactored out and played with in evals + normal demo
|
||||
def _run_agent_loop(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    ctx: ChatTurnContext,
    prompt_config: PromptConfig,
) -> None:
    """Run the agent repeatedly until a call produces a final response.

    Each pass streams one agent run, then rebuilds the message list from the
    run's output with a refreshed task prompt based on the content of the
    last entry of ``messages``. The loop ends after a pass that made no tool
    calls, or that called one of the stopping tools.

    Args:
        messages: Initial messages for the agent.
        dependencies: Supplies the model, model settings and tools.
        chat_session_id: ID of the chat session, forwarded to stream
            processing.
        ctx: Chat turn context handed to the agent run and prompt rebuild.
        prompt_config: Prompt configuration used when rebuilding the task
            prompt.
    """
    # TODO: Make this configurable on OnyxAgent level
    stopping_tools = ["image_generation"]

    agent = Agent(
        name="Assistant",
        model=dependencies.llm_model,
        tools=cast(list[AgentToolType], dependencies.tools),
        model_settings=dependencies.model_settings,
        tool_use_behavior="stop_on_first_tool",
    )

    current_messages: list[dict] = messages
    while True:
        agent_stream: SyncAgentStream = SyncAgentStream(
            agent=agent,
            input=current_messages,
            context=ctx,
        )
        streamed, tool_call_events = _process_stream(
            agent_stream, chat_session_id, dependencies, ctx
        )

        run_history = cast(list[dict], streamed.to_input_list())
        current_messages = _remove_last_task_prompt_and_insert_new_one(
            messages[-1]["content"], run_history, prompt_config, ctx
        )

        called_stopping_tool = any(
            tool.name in stopping_tools for tool in tool_call_events
        )
        if not tool_call_events or called_stopping_tool:
            break
|
||||
|
||||
|
||||
def _fast_chat_turn_core(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    prompt_config: PromptConfig,
    # Dependency injectable arguments for testing
    starter_global_iteration_responses: list[IterationAnswer] | None = None,
    starter_cited_documents: list[InferenceSection] | None = None,
) -> None:
    """Core fast chat turn logic with injectable starting state for testing.

    Steps: clear any pending stop signal for the session, build the turn
    context, run the agent loop under a ``fast_chat_turn`` trace, extract the
    final answer from the emitter's packet history, process citations when
    both context documents and a final answer exist, save the iteration, and
    emit an ``OverallStop`` packet.

    Args:
        messages: List of chat messages
        dependencies: Chat turn dependencies
        chat_session_id: Chat session ID
        message_id: Message ID
        research_type: Research type
        prompt_config: Prompt configuration forwarded to the agent loop
        starter_global_iteration_responses: Optional list of iteration answers
            to seed the aggregated context with (for testing); defaults to
            an empty list
        starter_cited_documents: Optional list of cited documents to seed the
            aggregated context with (for testing); defaults to an empty list
    """
    reset_cancel_status(
        chat_session_id,
        dependencies.redis_client,
    )
    ctx = ChatTurnContext(
        run_dependencies=dependencies,
        aggregated_context=AggregatedDRContext(
            context="context",
            cited_documents=starter_cited_documents or [],
            is_internet_marker_dict={},
            global_iteration_responses=starter_global_iteration_responses or [],
        ),
        iteration_instructions=[],
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
    )
    with trace("fast_chat_turn"):
        _run_agent_loop(
            messages=messages,
            dependencies=dependencies,
            chat_session_id=chat_session_id,
            ctx=ctx,
            prompt_config=prompt_config,
        )
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )

    # Stays empty unless the turn recorded at least one iteration response.
    all_cited_documents = []
    if ctx.aggregated_context.global_iteration_responses:
        context_docs = _gather_context_docs_from_iteration_answers(
            ctx.aggregated_context.global_iteration_responses
        )
        all_cited_documents = context_docs
        if context_docs and final_answer:
            _process_citations_for_final_answer(
                final_answer=final_answer,
                context_docs=context_docs,
                dependencies=dependencies,
                ctx=ctx,
            )

    save_iteration(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        ctx=ctx,
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
    )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )
|
||||
|
||||
|
||||
@unified_event_stream
def fast_chat_turn(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    prompt_config: PromptConfig,
) -> None:
    """Public entry point for a fast chat turn.

    Forwards every argument unchanged to ``_fast_chat_turn_core`` and
    explicitly passes ``starter_global_iteration_responses=None``, i.e. no
    pre-seeded iteration answers are injected.

    The function is decorated with ``unified_event_stream``, so what callers
    actually get back is whatever that decorator's wrapper returns rather
    than this function's own ``None``.
    """
    _fast_chat_turn_core(
        messages,
        dependencies,
        chat_session_id,
        message_id,
        research_type,
        prompt_config,
        starter_global_iteration_responses=None,
    )
|
||||
|
||||
|
||||
def _process_stream(
    agent_stream: SyncAgentStream,
    chat_session_id: UUID,
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
    emit_message_to_user: bool = True,
) -> tuple[RunResultStreaming, list["ResponseFunctionToolCall"]]:
    """Drain ``agent_stream``, forwarding packets and collecting tool calls.

    For every streamed event:
      * if ``is_connected`` reports the session as disconnected, clean-up
        packets are emitted, the stream is cancelled and iteration stops;
      * when ``emit_message_to_user`` is true, the event is translated with
        ``_default_packet_translation`` and any resulting object is emitted
        at ``ctx.current_run_step``;
      * events whose ``item`` attribute is a ``ToolCallItem`` contribute
        their ``raw_item`` to the returned tool-call list.

    Returns:
        ``(agent_stream.streamed, tool_call_events)``.

    Raises:
        ValueError: if ``agent_stream.streamed`` is ``None`` once the loop
            has finished.
    """
    from litellm import ResponseFunctionToolCall

    collected_tool_calls: list[ResponseFunctionToolCall] = []

    for event in agent_stream:
        # Connection state is re-checked on every event so that a cancel
        # takes effect in the middle of a stream.
        if not is_connected(chat_session_id, dependencies.redis_client):
            _emit_clean_up_packets(dependencies, ctx)
            agent_stream.cancel()
            break

        if emit_message_to_user:
            translated = _default_packet_translation(event, ctx)
            if translated:
                dependencies.emitter.emit(
                    Packet(ind=ctx.current_run_step, obj=translated)
                )

        # Not every event type carries an ``item``; hence getattr with default.
        item = getattr(event, "item", None)
        if isinstance(item, ToolCallItem):
            collected_tool_calls.append(cast(ResponseFunctionToolCall, item.raw_item))

    run_result = agent_stream.streamed
    if run_result is None:
        raise ValueError("agent_stream.streamed is None")
    return run_result, collected_tool_calls
|
||||
|
||||
|
||||
# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
|
||||
def _emit_clean_up_packets(
    dependencies: ChatTurnDependencies, ctx: ChatTurnContext
) -> None:
    """Emit the packets that close out a cancelled turn.

    If the most recently emitted packet is not a ``message_delta`` (including
    the case where nothing has been emitted yet), a ``MessageStart`` with the
    content ``"Cancelled"`` is emitted first. A ``SectionEnd`` is then
    emitted in every case. Both packets use ``ctx.current_run_step`` as index.
    """
    history = dependencies.emitter.packet_history
    # True only when the last thing sent to the client was streamed message text.
    ended_on_message_delta = (
        bool(history) and history[-1].obj.type == "message_delta"
    )

    if not ended_on_message_delta:
        cancelled_notice = MessageStart(
            type="message_start", content="Cancelled", final_documents=None
        )
        dependencies.emitter.emit(
            Packet(ind=ctx.current_run_step, obj=cancelled_notice)
        )

    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=SectionEnd(type="section_end"))
    )
|
||||
|
||||
|
||||
def _gather_context_docs_from_iteration_answers(
|
||||
iteration_answers: list[IterationAnswer],
|
||||
) -> list[InferenceSection]:
|
||||
"""Gather cited documents from iteration answers for citation processing."""
|
||||
context_docs: list[InferenceSection] = []
|
||||
|
||||
for iteration_answer in iteration_answers:
|
||||
# Extract cited documents from this iteration
|
||||
for inference_section in iteration_answer.cited_documents.values():
|
||||
# Avoid duplicates by checking document_id
|
||||
if not any(
|
||||
doc.center_chunk.document_id
|
||||
== inference_section.center_chunk.document_id
|
||||
for doc in context_docs
|
||||
):
|
||||
context_docs.append(inference_section)
|
||||
|
||||
return context_docs
|
||||
|
||||
|
||||
def _process_citations_for_final_answer(
    final_answer: str,
    context_docs: list[InferenceSection],
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> None:
    """Process citations in the final answer and emit citation events.

    The context sections are converted to LLM docs, ranked 1-based in their
    given order, and the complete ``final_answer`` is fed through a
    ``CitationProcessor`` in a single call. If any citation objects come
    back, a ``CitationStart`` / ``CitationDelta`` / ``SectionEnd`` packet
    triple is emitted at index ``ctx.current_run_step + 1``.

    Args:
        final_answer: Full text of the answer to scan for citations.
        context_docs: Sections the answer may cite; their order defines the
            citation numbers.
        dependencies: Provides the emitter used for the citation packets.
        ctx: Turn context; its ``current_run_step`` is advanced by one.
    """
    from onyx.chat.stream_processing.utils import DocumentIdOrderMapping

    index = ctx.current_run_step + 1

    # Convert InferenceSection objects to LlmDoc objects for citation processing
    llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]

    # Create document ID to rank mappings (simple 1-based indexing)
    final_doc_id_to_rank_map = DocumentIdOrderMapping(
        order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
    )
    display_doc_id_to_rank_map = final_doc_id_to_rank_map  # Same mapping for display

    # Initialize citation processor
    citation_processor = CitationProcessor(
        context_docs=llm_docs,
        final_doc_id_to_rank_map=final_doc_id_to_rank_map,
        display_doc_id_to_rank_map=display_doc_id_to_rank_map,
    )

    # Process the final answer through the citation processor, keeping only
    # the yielded parts that carry a ``citation_num`` (the citation infos).
    collected_citations: list = [
        response_part
        for response_part in citation_processor.process_token(final_answer)
        if hasattr(response_part, "citation_num")
    ]

    # Emit citation events if we found any citations
    if collected_citations:
        dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
        dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CitationDelta(citations=collected_citations),  # type: ignore[arg-type]
            )
        )
        dependencies.emitter.emit(Packet(ind=index, obj=SectionEnd(type="section_end")))

    # NOTE(review): the step counter is advanced whether or not citations
    # were emitted; the rendered source lost its indentation, so confirm this
    # assignment is not meant to live inside the ``if`` above.
    ctx.current_run_step = index
|
||||
|
||||
|
||||
def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
    """Translate a raw response stream event into a UI packet object.

    Only ``RawResponsesStreamEvent`` instances are translated, keyed on
    ``ev.data.type``:
      * ``response.content_part.added``  -> ``MessageStart`` with empty content
        and the search docs derived from ``ctx.aggregated_context.cited_documents``
      * ``response.output_text.delta``   -> ``MessageDelta`` carrying the delta
      * ``response.content_part.done``   -> ``SectionEnd``

    Returns ``None`` for every other event or event type.
    """
    if not isinstance(ev, RawResponsesStreamEvent):
        return None

    # TODO: might need some variation here for different types of models
    # OpenAI packet translator
    event_type = ev.data.type

    if event_type == "response.content_part.added":
        final_documents = convert_inference_sections_to_search_docs(
            ctx.aggregated_context.cited_documents
        )
        return MessageStart(
            type="message_start", content="", final_documents=final_documents
        )

    if event_type == "response.output_text.delta":
        return MessageDelta(type="message_delta", content=ev.data.delta)

    if event_type == "response.content_part.done":
        return SectionEnd(type="section_end")

    return None
|
||||
1
backend/onyx/chat/turn/infra/__init__.py
Normal file
1
backend/onyx/chat/turn/infra/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Infrastructure module for chat turn orchestration
|
||||
57
backend/onyx/chat/turn/infra/chat_turn_event_stream.py
Normal file
57
backend/onyx/chat/turn/infra/chat_turn_event_stream.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
|
||||
def unified_event_stream(
    turn_func: Callable[..., None],
) -> Callable[..., Generator[Packet, None]]:
    """
    Decorator that wraps a turn_func to provide event streaming capabilities.

    The decorated function is started via ``run_in_background`` while the
    returned generator reads packets from ``dependencies.emitter.bus`` and
    yields them to the caller. The generator finishes after yielding the
    ``OverallStop`` packet. If ``turn_func`` raises, the exception is shipped
    through the bus as a ``PacketException`` and re-raised from the generator.

    Note that the generator blocks on ``bus.get()`` until one of those two
    terminal packets arrives.

    Usage:
    @unified_event_stream
    def my_turn_func(messages, dependencies, *args, **kwargs):
        # Your turn logic here
        pass

    Then call it like:
    generator = my_turn_func(messages, dependencies, *args, **kwargs)
    """
    from functools import wraps

    # ``wraps`` keeps the decorated function's name, docstring and
    # ``__wrapped__`` reference available for logging and introspection.
    @wraps(turn_func)
    def wrapper(
        messages: List[Dict[str, Any]],
        dependencies: ChatTurnDependencies,
        *args: Any,
        **kwargs: Any
    ) -> Generator[Packet, None]:
        def run_with_exception_capture() -> None:
            try:
                turn_func(messages, dependencies, *args, **kwargs)
            except Exception as e:
                # Hand the failure to the consuming side through the bus so it
                # surfaces in the caller's context instead of dying silently
                # in the background.
                dependencies.emitter.emit(
                    Packet(ind=0, obj=PacketException(type="error", exception=e))
                )

        thread = run_in_background(run_with_exception_capture)
        while True:
            pkt: Packet = dependencies.emitter.bus.get()
            if pkt.obj == OverallStop(type="stop"):
                # The stop packet itself is still forwarded before finishing.
                yield pkt
                break
            elif isinstance(pkt.obj, PacketException):
                raise pkt.obj.exception
            else:
                yield pkt
        wait_on_background(thread)

    return wrapper
|
||||
21
backend/onyx/chat/turn/infra/emitter.py
Normal file
21
backend/onyx/chat/turn/infra/emitter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
    """Use this inside tools to emit arbitrary UI progress.

    Every emitted packet is put on ``bus`` for the consumer of the stream
    and also recorded, in emission order, in ``packet_history``.
    """

    def __init__(self, bus: Queue):
        # Ordered log of everything emitted through this instance.
        self.packet_history: list[Packet] = []
        # Queue the packets are handed off on.
        self.bus = bus

    def emit(self, packet: Packet) -> None:
        """Put ``packet`` on the bus, then append it to the history."""
        self.bus.put(packet)
        self.packet_history.append(packet)
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
    """Return a new ``Emitter`` backed by its own freshly created ``Queue``."""
    return Emitter(Queue())
|
||||
170
backend/onyx/chat/turn/infra/session_sink.py
Normal file
170
backend/onyx/chat/turn/infra/session_sink.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# TODO: Figure out a way to persist information that is robust to cancellation,
|
||||
# modular so easily testable in unit tests and evals [likely injecting some higher
|
||||
# level session manager and span sink], potentially has some robustness off the critical path,
|
||||
# and promotes clean separation of concerns.
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
def save_iteration(
    db_session: Session,
    message_id: int,
    chat_session_id: UUID,
    research_type: ResearchType,
    ctx: ChatTurnContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
) -> None:
    """Persist the outcome of a chat turn.

    In order, this:
      1. creates a search doc row for every section in ``all_cited_documents``
         (without committing) and links each one to ``message_id``;
      2. maps the citation numbers found in ``final_answer`` to search doc IDs;
      3. updates the chat message (and its parent) with the answer, citations,
         final documents and token count;
      4. adds one ``ResearchAgentIteration`` per entry in
         ``ctx.iteration_instructions`` and one
         ``ResearchAgentIterationSubStep`` per entry in
         ``ctx.aggregated_context.global_iteration_responses``;
      5. commits the session once at the end.

    Args:
        db_session: Session that all rows are added to and committed on.
        message_id: ID of the chat message being finalized.
        chat_session_id: ID of the chat session the message belongs to.
        research_type: Research type of the turn; ``ResearchType.DEEP`` marks
            the message as agentic.
        ctx: Turn context supplying the LLM config and iteration data.
        final_answer: Final answer text; an empty/falsy value is treated as "".
        all_cited_documents: Sections to store as the message's search docs.
            Citation number ``n`` in the answer refers to the ``n``-th entry.
    """
    # first, insert the search_docs
    # This marker dict is always empty here, so every doc is currently stored
    # with is_internet=False.
    is_internet_marker_dict: dict[str, bool] = {}
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map_search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    if search_docs:
        for cited_doc_nr in _extract_citation_numbers(final_answer or ""):
            # Citation numbers come from model-generated text and are 1-based
            # positions into ``search_docs``. Skip anything out of range: a
            # number that is too large would raise IndexError and abort the
            # whole save, and 0 would silently resolve to the last document
            # through negative indexing.
            if 1 <= cited_doc_nr <= len(search_docs):
                citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    llm_tokenizer = get_tokenizer(
        model_name=ctx.run_dependencies.llm.config.model_name,
        provider_type=ctx.run_dependencies.llm.config.model_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=chat_session_id,
        is_agentic=research_type == ResearchType.DEEP,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan={},
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # TODO: I don't think this is the ideal schema for all use cases
    # find a better schema to store tool and reasoning calls
    for iteration_preparation in ctx.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    for iteration_answer in ctx.aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            is_web_fetch=iteration_answer.is_web_fetch,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
|
||||
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Add one chat_message__search_doc link row per search doc ID to the session.

    The rows are only added to ``db_session``; nothing is committed here.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link to the message
        db_session: The database session
    """
    for doc_id in search_doc_ids:
        link_row = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=doc_id
        )
        db_session.add(link_row)
|
||||
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
|
||||
|
||||
def extract_final_answer_from_packets(packet_history: list[Packet]) -> str:
    """Extract the final answer from the emitted packets.

    Concatenates, in emission order, the ``content`` of every packet whose
    ``obj`` is a ``MessageDelta`` or a ``MessageStart``; all other packets
    are ignored. Returns an empty string when there are no such packets.
    """
    return "".join(
        packet.obj.content
        for packet in packet_history
        if isinstance(packet.obj, (MessageDelta, MessageStart))
    )
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user