Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-22 02:05:46 +00:00)

Compare commits: 5 commits (thread_sen... to feat/resol...)
| Author | SHA1 | Date |
|---|---|---|
| | d81337e345 | |
| | df3c8982a1 | |
| | dbb720e7f9 | |
| | f68b9526fb | |
| | 6460b5df4b | |
@@ -1,8 +0,0 @@
# Exclude these commits from git blame (e.g. mass reformatting).
# These are ignored by GitHub automatically.
# To enable this locally, run:
#
# git config blame.ignoreRevsFile .git-blame-ignore-revs

3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
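The removed file above documents how to make `git blame` skip mass-reformatting commits. For reference, a minimal shell sketch of the two standard ways to apply such a file locally; both flags are stock git features, and the blamed file path is only a placeholder:

```bash
# One-time repository config: git blame then consults the listed revisions automatically.
git config blame.ignoreRevsFile .git-blame-ignore-revs

# Or pass the file explicitly for a single invocation (file path is hypothetical).
git blame --ignore-revs-file .git-blame-ignore-revs backend/some_file.py
```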
7  .github/CODEOWNERS  (vendored)
@@ -1,10 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara

# Web standards updates
/web/STANDARDS.md @raunakab @Weves

# Agent context files
/CLAUDE.md.template @Weves
/AGENTS.md.template @Weves
@@ -7,6 +7,14 @@ inputs:
runs:
  using: "composite"
  steps:
    - name: Setup uv
      uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
      with:
        version: "0.9.9"
      # TODO: Enable caching once there is a uv.lock file checked in.
      # with:
      #   enable-cache: true

    - name: Compute requirements hash
      id: req-hash
      shell: bash
@@ -22,8 +30,6 @@ runs:
          done <<< "$REQUIREMENTS"
          echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
      # NOTE: This comes before Setup uv since clean-ups run in reverse chronological order
      # such that Setup uv's prune-cache is able to prune the cache before we upload.
      - name: Cache uv cache directory
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        with:
@@ -32,14 +38,6 @@ runs:
          restore-keys: |
            ${{ runner.os }}-uv-

    - name: Setup uv
      uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
      with:
        version: "0.9.9"
      # TODO: Enable caching once there is a uv.lock file checked in.
      # with:
      #   enable-cache: true

    - name: Setup Python
      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
      with:
153  .github/workflows/deployment.yml  (vendored)
@@ -11,6 +11,7 @@ on:
permissions: {}

env:
  IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
  EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}

jobs:
@@ -30,46 +31,27 @@ jobs:
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
|
||||
# Initialize all flags to false
|
||||
IS_CLOUD=false
|
||||
IS_NIGHTLY=false
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
BUILD_MODEL_SERVER=true
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
fi
|
||||
if [[ "$TAG" == nightly* ]]; then
|
||||
IS_NIGHTLY=true
|
||||
fi
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then
|
||||
IS_VERSION_TAG=true
|
||||
fi
|
||||
# Version checks (for web - any stable version)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
IS_STABLE=true
|
||||
fi
|
||||
@@ -77,35 +59,26 @@ jobs:
|
||||
IS_BETA=true
|
||||
fi
|
||||
|
||||
# Determine what to build based on tag type
|
||||
if [[ "$IS_CLOUD" == "true" ]]; then
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Skip desktop builds on beta tags and nightly runs
|
||||
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
|
||||
# Skip desktop builds on beta tags
|
||||
if [[ "$IS_BETA" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
# Version checks (for backend/model-server - stable version excluding cloud tags)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
IS_PROD_TAG=true
|
||||
fi
|
||||
|
||||
# Determine if this is a test run (workflow_dispatch on non-production ref)
|
||||
if [[ "$EVENT_NAME" == "workflow_dispatch" ]] && [[ "$IS_PROD_TAG" != "true" ]]; then
|
||||
IS_TEST_RUN=true
|
||||
fi
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
{
|
||||
echo "build-desktop=$BUILD_DESKTOP"
|
||||
echo "build-web=$BUILD_WEB"
|
||||
@@ -117,7 +90,6 @@ jobs:
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
} >> "$GITHUB_OUTPUT"
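To make the flag logic above easier to follow, here is a small, self-contained bash sketch that classifies a few sample tags with the same patterns the step uses; the sample tag names are illustrative only and are not taken from the repository's release history:

```bash
#!/usr/bin/env bash
classify() {
  local tag="$1" kind="other"
  # Same pattern checks as the determine-builds step, reduced to a single label.
  if [[ "$tag" == *cloud* ]]; then kind="cloud"; fi
  if [[ "$tag" == nightly* ]]; then kind="nightly"; fi
  if [[ "$tag" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then kind="stable"; fi
  if [[ "$tag" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]]; then kind="beta"; fi
  echo "$tag -> $kind"
}

# Illustrative inputs.
for t in v1.2.3 v1.2.3-beta.4 v1.2.3-cloud nightly-latest-20250101 feature/foo; do
  classify "$t"
done
```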
|
||||
@@ -229,9 +201,9 @@ jobs:
|
||||
working-directory: ./desktop
|
||||
env:
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
if [ "${EVENT_NAME}" == "workflow_dispatch" ]; then
|
||||
VERSION="0.0.0-dev+${SHORT_SHA}"
|
||||
else
|
||||
VERSION="${GITHUB_REF_NAME#v}"
|
||||
@@ -256,11 +228,9 @@ jobs:
|
||||
if: runner.os == 'Windows'
|
||||
working-directory: ./desktop
|
||||
shell: pwsh
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
run: |
|
||||
# Windows MSI requires numeric-only build metadata, so we skip the SHA suffix
|
||||
if ($env:IS_TEST_RUN -eq "true") {
|
||||
if ("${{ github.event_name }}" -eq "workflow_dispatch") {
|
||||
$VERSION = "0.0.0"
|
||||
} else {
|
||||
# Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility
|
||||
@@ -289,8 +259,8 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
tagName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
@@ -321,7 +291,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -351,7 +321,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-arm64:
|
||||
@@ -379,7 +349,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -409,7 +379,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web:
|
||||
@@ -441,18 +411,18 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
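The "Create and push manifest" command itself falls outside this hunk; only its env block (IMAGE_REPO, the two per-arch digests, and META_TAGS) is visible. As a hypothetical sketch of how such inputs are commonly combined into a multi-arch manifest, one option is `docker buildx imagetools create`; the actual command the workflow runs may differ:

```bash
# Hypothetical sketch; the real manifest command is not visible in this diff.
# IMAGE_REPO, AMD64_DIGEST, ARM64_DIGEST, META_TAGS come from the step env above.
for tag in $META_TAGS; do
  docker buildx imagetools create \
    --tag "$tag" \
    "${IMAGE_REPO}@${AMD64_DIGEST}" \
    "${IMAGE_REPO}@${ARM64_DIGEST}"
done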
|
||||
@@ -487,7 +457,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -525,7 +495,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-cloud-arm64:
|
||||
@@ -553,7 +523,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -591,7 +561,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web-cloud:
|
||||
@@ -623,15 +593,15 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -666,7 +636,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -695,7 +665,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-arm64:
|
||||
@@ -723,7 +693,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -752,7 +722,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-backend:
|
||||
@@ -784,18 +754,18 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -831,7 +801,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -864,7 +834,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -895,7 +865,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -928,7 +898,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -962,18 +932,18 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -1006,7 +976,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1046,7 +1016,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-cloud-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1091,7 +1061,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1133,7 +1103,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1152,7 +1122,6 @@ jobs:
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs:
|
||||
- determine-builds
|
||||
- build-desktop
|
||||
- build-web-amd64
|
||||
- build-web-arm64
|
||||
@@ -1166,7 +1135,7 @@ jobs:
|
||||
- build-model-server-amd64
|
||||
- build-model-server-arm64
|
||||
- merge-model-server
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
|
||||
@@ -38,8 +38,6 @@ env:
  # LLMs
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
  VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
  VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}

  # Code Interpreter
  # TODO: debug why this is failing and enable
404  .github/workflows/pr-helm-chart-testing.yml  (vendored)
@@ -6,11 +6,11 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
branches: [ main ]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -18,233 +18,225 @@ permissions:
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-x64,
|
||||
hdd=256,
|
||||
"run-id=${{ github.run_id }}-helm-chart-check",
|
||||
]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
|
||||
timeout-minutes: 45
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@6ec842c01de15ebb84c8627d2744a0c2f2755c9f # ratchet:helm/chart-testing-action@v2.8.0
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
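The step above only writes `changed=true` when `ct list-changed` reports modified charts. To reproduce the check locally, something like the following should work, assuming the chart-testing CLI (`ct`) is installed and `main` is the target branch on `origin`:

```bash
# List charts that differ from the default branch, mirroring the workflow step.
changed=$(ct list-changed --remote origin --target-branch main --chart-dirs deployment/helm/charts)

if [[ -n "$changed" ]]; then
  echo "Charts changed:"
  echo "$changed"
else
  echo "No chart changes detected."
fi
```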
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
17  .github/workflows/pr-integration-tests.yml  (vendored)
@@ -56,7 +56,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
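The comment says a JSON array with directory info is created next, but that part of the script is outside the hunk. A hedged sketch of one way to turn the two directory lists into a matrix JSON object (hypothetical; the field names are only inferred from the `matrix.test-dir.path` references later in the workflow, and the real script may build the JSON differently):

```bash
# tests_dirs / connector_dirs are the newline-separated names from the find commands above.
# jq -R maps each line to an object; -cs slurps them into a compact include array.
matrix=$(
  {
    echo "$tests_dirs"     | jq -R '{"test-dir": {path: ("tests/" + .)}}'
    echo "$connector_dirs" | jq -R '{"test-dir": {path: ("connector_job_tests/" + .)}}'
  } | jq -cs '{include: .}'
)
echo "matrix=$matrix" >> "$GITHUB_OUTPUT"
```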
|
||||
@@ -310,7 +310,6 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
MCP_SERVER_ENABLED=true
|
||||
EOF
|
||||
|
||||
@@ -325,6 +324,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -367,6 +367,12 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -396,6 +402,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
@@ -480,10 +488,10 @@ jobs:
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -492,6 +500,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
@@ -540,6 +549,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
|
||||
12  .github/workflows/pr-mit-integration-tests.yml  (vendored)
@@ -48,7 +48,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -301,7 +301,6 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -315,6 +314,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -357,6 +357,12 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -387,6 +393,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
|
||||
11  .github/workflows/zizmor.yml  (vendored)
@@ -21,29 +21,18 @@ jobs:
        with:
          persist-credentials: false

      - name: Detect changes
        id: filter
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
        with:
          filters: |
            zizmor:
              - '.github/**'

      - name: Install the latest version of uv
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"

      - name: Run zizmor
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
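The CI invocation above runs zizmor through `uv run --with`. Roughly the same check can be reproduced locally with uv's tool runner; a plain-text report is usually easier to read than SARIF during development, and a GH_TOKEN may be needed for audits that query the GitHub API, as in the CI env block:

```bash
# Run the same static analysis locally; omit --format=sarif for human-readable output.
uvx zizmor .

# Or mirror the CI command exactly and produce a SARIF file.
uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
```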
      - name: Upload SARIF file
        if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
        uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
        with:
          sarif_file: results.sarif
||||
@@ -9,7 +9,7 @@ repos:
    rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
    hooks:
      - id: uv-sync
        args: ["--locked", "--all-extras"]
        args: ["--active", "--locked", "--all-extras"]
      - id: uv-lock
        files: ^pyproject\.toml$
      - id: uv-export
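The added `--active` flag tells `uv sync` to sync into the currently active virtual environment rather than the project's default `.venv`. A hedged local equivalent of what the hook runs, assuming a venv activated as described in the AGENTS.md change later in this diff and run from the directory containing `pyproject.toml`:

```bash
# Roughly what the uv-sync pre-commit hook runs against an already-activated venv.
source backend/.venv/bin/activate
uv sync --active --locked --all-extras
```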
51  .vscode/env_template.txt  (vendored)
@@ -1,45 +1,36 @@
# Copy this file to .env in the .vscode folder.
# Fill in the <REPLACE THIS> values as needed; it is recommended to set the
# GEN_AI_API_KEY value to avoid having to set up an LLM in the UI.
# Also check out onyx/backend/scripts/restart_containers.sh for a script to
# restart the containers which Onyx relies on outside of VSCode/Cursor
# processes.
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes

# For local dev, often user Authentication is not needed.
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled

# Always keep these on for Dev.
# Logs model prompts, reasoning, and answer to stdout.
# Always keep these on for Dev
# Logs model prompts, reasoning, and answer to stdout
LOG_ONYX_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug

# This passes top N results to LLM an additional time for reranking prior to
# answer generation.
# This step is quite heavy on token usage so we disable it for dev generally.
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=False

# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
OPENID_CONFIG_URL=<REPLACE THIS>
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config

# Generally not useful for dev, we don't generally want to set up an SMTP server
# for dev.
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False

# Set these so if you wipe the DB, you don't end up having to go through the UI
# every time.
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper.
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o

@@ -49,36 +40,26 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1

# Enable the full set of Danswer Enterprise Edition features.
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you
# are using this for local testing/development).
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# S3 File Store Configuration (MinIO for local development)
S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin

# Show extra/uncommon connectors.
# Show extra/uncommon connectors
SHOW_EXTRA_CONNECTORS=True

# Local langsmith tracing
LANGSMITH_TRACING="true"
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
LANGSMITH_API_KEY=<REPLACE_THIS>
LANGSMITH_PROJECT=<REPLACE_THIS>

# Local Confluence OAuth testing
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
# NEXT_PUBLIC_TEST_ENV=True

# OpenSearch
# Arbitrary password is fine for local development.
OPENSEARCH_INITIAL_ADMIN_PASSWORD=<REPLACE THIS>
# NEXT_PUBLIC_TEST_ENV=True
15
.vscode/launch.template.jsonc
vendored
@@ -512,21 +512,6 @@
          "group": "3"
        }
      },
      {
        "name": "Clear and Restart OpenSearch Container",
        // Generic debugger type, required arg but has no bearing on bash.
        "type": "node",
        "request": "launch",
        "runtimeExecutable": "bash",
        "runtimeArgs": [
          "${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
        ],
        "cwd": "${workspaceFolder}",
        "console": "integratedTerminal",
        "presentation": {
          "group": "3"
        }
      },
      {
        "name": "Eval CLI",
        "type": "debugpy",
@@ -1,13 +1,13 @@
# AGENTS.md

This file provides guidance to AI agents when working with code in this repository.
This file provides guidance to Codex when working with code in this repository.

## KEY NOTES

- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
  to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
  `a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
  make sure we see logs coming out from the relevant service.
@@ -181,286 +181,6 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `web/tailwind-themes/tailwind.config.js` / `web/src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
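
As a rough illustration of this rule (not part of the original standards text), a component might fetch its own data with `useSWR` and show a placeholder while loading; the endpoint, response shape, and copy below are assumptions for the sketch:

```typescript
// Hypothetical sketch of component-local data fetching with useSWR.
// The endpoint and UserSettings shape are illustrative, not real Onyx code.
import useSWR from "swr";
import { Text } from "@/refresh-components/texts/Text";

interface UserSettings {
  displayName: string;
}

const fetcher = (url: string) => fetch(url).then((res) => res.json());

function UserSettingsPanel() {
  // Fetch inside the component that needs the data, not in a parent.
  const { data, error, isLoading } = useSWR<UserSettings>(
    "/api/user/settings", // assumed endpoint for illustration
    fetcher
  );

  if (isLoading) return <Text text03>Loading…</Text>;
  if (error || !data) return <Text text03>Failed to load settings</Text>;

  return <Text mainAction>{data.displayName}</Text>;
}
```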
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -575,6 +295,14 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
@@ -184,286 +184,6 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
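
A minimal sketch of the same idea, using SWR's conditional-key behavior so the fetch only starts once its inputs exist (endpoint and response type are assumptions, not real Onyx code):

```typescript
// Hypothetical sketch only: component-local fetching with a loading placeholder.
import useSWR from "swr";

interface ProjectSummary {
  name: string;
}

const fetcher = (url: string) => fetch(url).then((res) => res.json());

function ProjectHeader({ projectId }: { projectId: string | null }) {
  // Passing a null key skips fetching until projectId is known (standard SWR behavior).
  const { data, isLoading } = useSWR<ProjectSummary>(
    projectId ? `/api/projects/${projectId}` : null, // assumed endpoint
    fetcher
  );

  if (isLoading || !data) return <div className="p-2">Loading…</div>;
  return <div className="p-2">{data.name}</div>;
}
```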
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -580,6 +300,14 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""usage_limits
|
||||
|
||||
Revision ID: 2b90f3af54b8
|
||||
Revises: 9a0296d7421e
|
||||
Create Date: 2026-01-03 16:55:30.449692
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2b90f3af54b8"
|
||||
down_revision = "9a0296d7421e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tenant_usage",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"window_start", sa.DateTime(timezone=True), nullable=False, index=True
|
||||
),
|
||||
sa.Column("llm_cost_cents", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("chunks_indexed", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("api_calls", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"non_streaming_api_calls", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("window_start", name="uq_tenant_usage_window"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_tenant_usage_window_start", table_name="tenant_usage")
|
||||
op.drop_table("tenant_usage")
|
||||
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.llm.llm_provider_options import (
    fetch_model_names_for_provider_as_set,
    fetch_visible_model_names_for_provider_as_set,
)
@@ -1,35 +0,0 @@
|
||||
"""backend driven notification details
|
||||
|
||||
Revision ID: 5c3dca366b35
|
||||
Revises: 9087b548dd69
|
||||
Create Date: 2026-01-06 16:03:11.413724
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c3dca366b35"
|
||||
down_revision = "9087b548dd69"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column(
|
||||
"title", sa.String(), nullable=False, server_default="New Notification"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column("description", sa.String(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "title")
|
||||
op.drop_column("notification", "description")
|
||||
@@ -1,75 +0,0 @@
|
||||
"""nullify_default_task_prompt
|
||||
|
||||
Revision ID: 699221885109
|
||||
Revises: 7e490836d179
|
||||
Create Date: 2025-12-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "699221885109"
|
||||
down_revision = "7e490836d179"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make task_prompt column nullable
|
||||
# Note: The model had nullable=True but the DB column was NOT NULL until this point
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set task_prompt to NULL for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = NULL
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore task_prompt to empty string for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE id = :persona_id AND task_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
# Set any remaining NULL task_prompts to empty string before making non-nullable
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE task_prompt IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Revert task_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,54 +0,0 @@
|
||||
"""add image generation config table
|
||||
|
||||
Revision ID: 7206234e012a
|
||||
Revises: 699221885109
|
||||
Create Date: 2025-12-21 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7206234e012a"
|
||||
down_revision = "699221885109"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"image_generation_config",
|
||||
sa.Column("image_provider_id", sa.String(), primary_key=True),
|
||||
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
|
||||
sa.Column("is_default", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_configuration_id"],
|
||||
["model_configuration.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_is_default",
|
||||
"image_generation_config",
|
||||
["is_default"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
"image_generation_config",
|
||||
["model_configuration_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
table_name="image_generation_config",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_is_default", table_name="image_generation_config"
|
||||
)
|
||||
op.drop_table("image_generation_config")
|
||||
@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.llm.llm_provider_options import (
    fetch_model_names_for_provider_as_set,
    fetch_visible_model_names_for_provider_as_set,
)
@@ -1,80 +0,0 @@
|
||||
"""nullify_default_system_prompt
|
||||
|
||||
Revision ID: 7e490836d179
|
||||
Revises: c1d2e3f4a5b6
|
||||
Create Date: 2025-12-29 16:54:36.635574
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7e490836d179"
|
||||
down_revision = "c1d2e3f4a5b6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# This is the default system prompt from the previous migration (87c52ec39f84)
|
||||
# ruff: noqa: E501, W605 start
|
||||
PREVIOUS_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make system_prompt column nullable (model already has nullable=True but DB doesn't)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set system_prompt to NULL where it matches the previous default
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = NULL
|
||||
WHERE system_prompt = :previous_default
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the default system prompt for personas that have NULL
|
||||
# Note: This may restore the prompt to personas that originally had NULL
|
||||
# before this migration, but there's no way to distinguish them
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :previous_default
|
||||
WHERE system_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
# Revert system_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,136 +0,0 @@
|
||||
"""seed_default_image_gen_config
|
||||
|
||||
Revision ID: 9087b548dd69
|
||||
Revises: 2b90f3af54b8
|
||||
Create Date: 2026-01-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9087b548dd69"
|
||||
down_revision = "2b90f3af54b8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for default image generation config
|
||||
# Source: web/src/app/admin/configuration/image-generation/constants.ts
|
||||
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
|
||||
MODEL_NAME = "gpt-image-1"
|
||||
PROVIDER_NAME = "openai"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if image_generation_config table already has records
|
||||
existing_configs = (
|
||||
conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
if existing_configs > 0:
|
||||
# Skip if configs already exist - user may have configured manually
|
||||
return
|
||||
|
||||
# Find the first OpenAI LLM provider
|
||||
openai_provider = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, api_key
|
||||
FROM llm_provider
|
||||
WHERE provider = :provider
|
||||
ORDER BY id
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"provider": PROVIDER_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not openai_provider:
|
||||
# No OpenAI provider found - nothing to do
|
||||
return
|
||||
|
||||
source_provider_id, api_key = openai_provider
|
||||
|
||||
# Create new LLM provider for image generation (clone only api_key)
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO llm_provider (
|
||||
name, provider, api_key, api_base, api_version,
|
||||
deployment_name, default_model_name, is_public,
|
||||
is_default_provider, is_default_vision_provider, is_auto_mode
|
||||
)
|
||||
VALUES (
|
||||
:name, :provider, :api_key, NULL, NULL,
|
||||
NULL, :default_model_name, :is_public,
|
||||
NULL, NULL, :is_auto_mode
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"name": f"Image Gen - {IMAGE_PROVIDER_ID}",
|
||||
"provider": PROVIDER_NAME,
|
||||
"api_key": api_key,
|
||||
"default_model_name": MODEL_NAME,
|
||||
"is_public": True,
|
||||
"is_auto_mode": False,
|
||||
},
|
||||
)
|
||||
new_provider_id = result.scalar()
|
||||
|
||||
# Create model configuration
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO model_configuration (
|
||||
llm_provider_id, name, is_visible, max_input_tokens,
|
||||
supports_image_input, display_name
|
||||
)
|
||||
VALUES (
|
||||
:llm_provider_id, :name, :is_visible, :max_input_tokens,
|
||||
:supports_image_input, :display_name
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"llm_provider_id": new_provider_id,
|
||||
"name": MODEL_NAME,
|
||||
"is_visible": True,
|
||||
"max_input_tokens": None,
|
||||
"supports_image_input": False,
|
||||
"display_name": None,
|
||||
},
|
||||
)
|
||||
model_config_id = result.scalar()
|
||||
|
||||
# Create image generation config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO image_generation_config (
|
||||
image_provider_id, model_configuration_id, is_default
|
||||
)
|
||||
VALUES (
|
||||
:image_provider_id, :model_configuration_id, :is_default
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"image_provider_id": IMAGE_PROVIDER_ID,
|
||||
"model_configuration_id": model_config_id,
|
||||
"is_default": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the config on downgrade since it's safe to keep around
|
||||
# If we upgrade again, it will be a no-op due to the existing records check
|
||||
pass
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add_is_auto_mode_to_llm_provider
|
||||
|
||||
Revision ID: 9a0296d7421e
|
||||
Revises: 7206234e012a
|
||||
Create Date: 2025-12-17 18:14:29.620981
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9a0296d7421e"
|
||||
down_revision = "7206234e012a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_auto_mode",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "is_auto_mode")
|
||||
@@ -234,8 +234,6 @@ def downgrade() -> None:
    if "instructions" in columns:
        op.drop_column("user_project", "instructions")
    op.execute("ALTER TABLE user_project RENAME TO user_folder")
    # Update NULL descriptions to empty string before setting NOT NULL constraint
    op.execute("UPDATE user_folder SET description = '' WHERE description IS NULL")
    op.alter_column("user_folder", "description", nullable=False)
    logger.info("Renamed user_project back to user_folder")
@@ -1,39 +0,0 @@
|
||||
"""remove userfile related deprecated fields
|
||||
|
||||
Revision ID: a3c1a7904cd0
|
||||
Revises: 5c3dca366b35
|
||||
Create Date: 2026-01-06 13:00:30.634396
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3c1a7904cd0"
|
||||
down_revision = "5c3dca366b35"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user_file", "document_id")
|
||||
op.drop_column("user_file", "document_id_migrated")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
@@ -111,6 +111,10 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")

# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -34,7 +34,6 @@ def make_persona_private(
        create_notification(
            user_id=user_id,
            notif_type=NotificationType.PERSONA_SHARED,
            title="A new agent was shared with you!",
            db_session=db_session,
            additional_data=PersonaSharedNotificationData(
                persona_id=persona_id,
@@ -21,9 +21,8 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.utils import PUBLIC_API_TAGS

router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/analytics")


_DEFAULT_LOOKBACK_DAYS = 30
@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any
from typing import List

@@ -24,12 +23,6 @@ class NavigationItem(BaseModel):
        return instance


class LogoDisplayStyle(str, Enum):
    LOGO_AND_NAME = "logo_and_name"
    LOGO_ONLY = "logo_only"
    NAME_ONLY = "name_only"


class EnterpriseSettings(BaseModel):
    """General settings that only apply to the Enterprise Edition of Onyx

@@ -38,7 +31,6 @@ class EnterpriseSettings(BaseModel):
    application_name: str | None = None
    use_custom_logo: bool = False
    use_custom_logotype: bool = False
    logo_display_style: LogoDisplayStyle | None = None

    # custom navigation
    custom_nav_items: List[NavigationItem] = Field(default_factory=list)
@@ -50,9 +42,6 @@ class EnterpriseSettings(BaseModel):
    custom_popup_header: str | None = None
    custom_popup_content: str | None = None
    enable_consent_screen: bool | None = None
    consent_screen_prompt: str | None = None
    show_first_visit_notice: bool | None = None
    custom_greeting_message: str | None = None

    def check_validity(self) -> None:
        return
@@ -106,6 +106,7 @@ def handle_simplified_chat_message(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
        enforce_chat_session_id_for_search_docs=False,
    )

    return gather_stream(packets)
@@ -209,6 +210,7 @@ def handle_send_message_simple_with_history(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
        enforce_chat_session_id_for_search_docs=False,
    )

    return gather_stream(packets)
@@ -48,7 +48,6 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id

@@ -295,7 +294,7 @@ def list_all_query_history_exports(
    )


@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
@router.post("/admin/query-history/start-export")
def start_query_history_export(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
@@ -341,7 +340,7 @@ def start_query_history_export(
    return {"request_id": task_id}


@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
@router.get("/admin/query-history/export-status")
def get_query_history_export_status(
    request_id: str,
    _: User | None = Depends(current_admin_user),
@@ -375,7 +374,7 @@ def get_query_history_export_status(
    return {"status": TaskStatus.SUCCESS}


@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
@router.get("/admin/query-history/download")
def download_query_history_csv(
    request_id: str,
    _: User | None = Depends(current_admin_user),
@@ -1,92 +0,0 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tenant_id to their specific limit overrides.
|
||||
Returns empty dict on any error (falls back to defaults).
|
||||
"""
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides"
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
tenant_overrides = response.json()
|
||||
|
||||
# Parse each tenant's overrides
|
||||
result: dict[str, TenantUsageLimitOverrides] = {}
|
||||
for override_data in tenant_overrides:
|
||||
tenant_id = override_data["tenant_id"]
|
||||
try:
|
||||
result[tenant_id] = TenantUsageLimitOverrides(**override_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
|
||||
Called at server startup to populate the in-memory cache.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
_tenant_usage_limit_overrides = overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
return overrides
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
tenant_id: str,
|
||||
) -> TenantUsageLimitOverrides | None:
|
||||
"""
|
||||
Get the usage limit overrides for a specific tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to look up
|
||||
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
if _tenant_usage_limit_overrides is None:
|
||||
_tenant_usage_limit_overrides = load_usage_limit_overrides()
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
@@ -9,7 +10,10 @@ from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
@@ -21,18 +25,11 @@ from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -40,24 +37,13 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
get_recommendations,
|
||||
)
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
model_configurations_for_provider,
|
||||
)
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import get_openai_model_names
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -66,7 +52,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
@@ -275,173 +261,59 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
|
||||
|
||||
def _build_model_configuration_upsert_requests(
|
||||
provider_name: str,
|
||||
recommendations: LLMRecommendations,
|
||||
) -> list[ModelConfigurationUpsertRequest]:
|
||||
model_configurations = model_configurations_for_provider(
|
||||
provider_name, recommendations
|
||||
)
|
||||
return [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_configuration.name,
|
||||
is_visible=model_configuration.is_visible,
|
||||
max_input_tokens=model_configuration.max_input_tokens,
|
||||
supports_image_input=model_configuration.supports_image_input,
|
||||
)
|
||||
for model_configuration in model_configurations
|
||||
]
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
    """Configure default LLM providers using recommended-models.json for model selection."""
    # Load recommendations from JSON config
    recommendations = get_recommendations()

    has_set_default_provider = False

    def _upsert(request: LLMProviderUpsertRequest) -> None:
        nonlocal has_set_default_provider
        try:
            provider = upsert_llm_provider(request, db_session)
            if not has_set_default_provider:
                update_default_provider(provider.id, db_session)
                has_set_default_provider = True
        except Exception as e:
            logger.error(f"Failed to configure {request.provider} provider: {e}")

    # Configure OpenAI provider
    if OPENAI_DEFAULT_API_KEY:
        default_model = recommendations.get_default_model(OPENAI_PROVIDER_NAME)
        if default_model is None:
            logger.error(
                f"No default model found for {OPENAI_PROVIDER_NAME} in recommendations"
            )
        default_model_name = default_model.name if default_model else "gpt-5.2"

        openai_provider = LLMProviderUpsertRequest(
            name="OpenAI",
            provider=OPENAI_PROVIDER_NAME,
            api_key=OPENAI_DEFAULT_API_KEY,
            default_model_name=default_model_name,
            model_configurations=_build_model_configuration_upsert_requests(
                OPENAI_PROVIDER_NAME, recommendations
            ),
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(openai_provider)

        # Create default image generation config using the OpenAI API key
        try:
            create_default_image_gen_config_from_api_key(
                db_session, OPENAI_DEFAULT_API_KEY
            )
        except Exception as e:
            logger.error(f"Failed to create default image gen config: {e}")
    else:
        logger.info(
            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
        )

    # Configure Anthropic provider
    if ANTHROPIC_DEFAULT_API_KEY:
        default_model = recommendations.get_default_model(ANTHROPIC_PROVIDER_NAME)
        if default_model is None:
            logger.error(
                f"No default model found for {ANTHROPIC_PROVIDER_NAME} in recommendations"
            )
        default_model_name = (
            default_model.name if default_model else "claude-sonnet-4-5"
        )

        anthropic_provider = LLMProviderUpsertRequest(
            name="Anthropic",
            provider=ANTHROPIC_PROVIDER_NAME,
            provider=LlmProviderNames.ANTHROPIC,
            api_key=ANTHROPIC_DEFAULT_API_KEY,
            default_model_name=default_model_name,
            model_configurations=_build_model_configuration_upsert_requests(
                ANTHROPIC_PROVIDER_NAME, recommendations
            ),
            default_model_name="claude-3-7-sonnet-20250219",
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name=name,
                    is_visible=False,
                    max_input_tokens=None,
                )
                for name in get_anthropic_model_names()
            ],
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(anthropic_provider)
        try:
            full_provider = upsert_llm_provider(anthropic_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure Anthropic provider: {e}")
    else:
        logger.info(
        logger.error(
            "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
        )

    # Configure Vertex AI provider
    if VERTEXAI_DEFAULT_CREDENTIALS:
        default_model = recommendations.get_default_model(VERTEXAI_PROVIDER_NAME)
        if default_model is None:
            logger.error(
                f"No default model found for {VERTEXAI_PROVIDER_NAME} in recommendations"
            )
        default_model_name = default_model.name if default_model else "gemini-2.5-pro"

        # Vertex AI uses custom_config for credentials and location
        custom_config = {
            VERTEX_CREDENTIALS_FILE_KWARG: VERTEXAI_DEFAULT_CREDENTIALS,
            VERTEX_LOCATION_KWARG: VERTEXAI_DEFAULT_LOCATION,
        }

        vertexai_provider = LLMProviderUpsertRequest(
            name="Google Vertex AI",
            provider=VERTEXAI_PROVIDER_NAME,
            custom_config=custom_config,
            default_model_name=default_model_name,
            model_configurations=_build_model_configuration_upsert_requests(
                VERTEXAI_PROVIDER_NAME, recommendations
            ),
    if OPENAI_DEFAULT_API_KEY:
        openai_provider = LLMProviderUpsertRequest(
            name="OpenAI",
            provider=LlmProviderNames.OPENAI,
            api_key=OPENAI_DEFAULT_API_KEY,
            default_model_name="gpt-4o",
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name=model_name,
                    is_visible=False,
                    max_input_tokens=None,
                )
                for model_name in get_openai_model_names()
            ],
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(vertexai_provider)
        try:
            full_provider = upsert_llm_provider(openai_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure OpenAI provider: {e}")
    else:
        logger.info(
            "VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
        logger.error(
            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
        )

    # Configure OpenRouter provider
    if OPENROUTER_DEFAULT_API_KEY:
        default_model = recommendations.get_default_model(OPENROUTER_PROVIDER_NAME)
        if default_model is None:
            logger.error(
                f"No default model found for {OPENROUTER_PROVIDER_NAME} in recommendations"
            )
        default_model_name = default_model.name if default_model else "z-ai/glm-4.7"

        # For OpenRouter, we use the visible models from recommendations as model_configurations
        # since OpenRouter models are dynamic (fetched from their API)
        visible_models = recommendations.get_visible_models(OPENROUTER_PROVIDER_NAME)
        model_configurations = [
            ModelConfigurationUpsertRequest(
                name=model.name,
                is_visible=True,
                max_input_tokens=None,
                display_name=model.display_name,
            )
            for model in visible_models
        ]

        openrouter_provider = LLMProviderUpsertRequest(
            name="OpenRouter",
            provider=OPENROUTER_PROVIDER_NAME,
            api_key=OPENROUTER_DEFAULT_API_KEY,
            default_model_name=default_model_name,
            model_configurations=model_configurations,
            api_key_changed=True,
            is_auto_mode=True,
        )
        _upsert(openrouter_provider)
    else:
        logger.info(
            "OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
        )

    # Configure Cohere embedding provider
    if COHERE_DEFAULT_API_KEY:
        cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
            provider_type=EmbeddingProvider.COHERE,

@@ -16,9 +16,8 @@ from onyx.db.token_limit import insert_user_token_rate_limit
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
from onyx.server.utils import PUBLIC_API_TAGS

router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/admin/token-rate-limits")


"""

@@ -1,47 +0,0 @@
"""EE Usage limits - trial detection via billing information."""

from datetime import datetime
from datetime import timezone

from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


def is_tenant_on_trial(tenant_id: str) -> bool:
    """
    Determine if a tenant is currently on a trial subscription.

    In multi-tenant mode, we fetch billing information from the control plane
    to determine if the tenant has an active trial.
    """
    if not MULTI_TENANT:
        return False

    try:
        billing_info = fetch_billing_information(tenant_id)

        # If not subscribed at all, check if we have trial information
        if isinstance(billing_info, SubscriptionStatusResponse):
            # No subscription means they're likely on trial (new tenant)
            return True

        if isinstance(billing_info, BillingInformation):
            # Check if trial is active
            if billing_info.trial_end is not None:
                now = datetime.now(timezone.utc)
                # Trial active if trial_end is in the future
                # and subscription status indicates trialing
                if billing_info.trial_end > now and billing_info.status == "trialing":
                    return True

        return False

    except Exception as e:
        logger.warning(f"Failed to fetch billing info for trial check: {e}")
        # Default to trial limits on error (more restrictive = safer)
        return True
@@ -21,12 +21,11 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/manage")


@router.get("/admin/user-group")

@@ -1,107 +0,0 @@
"""Captcha verification for user registration."""

import httpx
from pydantic import BaseModel
from pydantic import Field

from onyx.configs.app_configs import CAPTCHA_ENABLED
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
from onyx.utils.logger import setup_logger

logger = setup_logger()

RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"


class CaptchaVerificationError(Exception):
    """Raised when captcha verification fails."""


class RecaptchaResponse(BaseModel):
    """Response from Google reCAPTCHA verification API."""

    success: bool
    score: float | None = None  # Only present for reCAPTCHA v3
    action: str | None = None
    challenge_ts: str | None = None
    hostname: str | None = None
    error_codes: list[str] | None = Field(default=None, alias="error-codes")


def is_captcha_enabled() -> bool:
    """Check if captcha verification is enabled."""
    return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)


async def verify_captcha_token(
    token: str,
    expected_action: str = "signup",
) -> None:
    """
    Verify a reCAPTCHA token with Google's API.

    Args:
        token: The reCAPTCHA response token from the client
        expected_action: Expected action name for v3 verification

    Raises:
        CaptchaVerificationError: If verification fails
    """
    if not is_captcha_enabled():
        return

    if not token:
        raise CaptchaVerificationError("Captcha token is required")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                RECAPTCHA_VERIFY_URL,
                data={
                    "secret": RECAPTCHA_SECRET_KEY,
                    "response": token,
                },
                timeout=10.0,
            )
            response.raise_for_status()

        data = response.json()
        result = RecaptchaResponse(**data)

        if not result.success:
            error_codes = result.error_codes or ["unknown-error"]
            logger.warning(f"Captcha verification failed: {error_codes}")
            raise CaptchaVerificationError(
                f"Captcha verification failed: {', '.join(error_codes)}"
            )

        # For reCAPTCHA v3, also check the score
        if result.score is not None:
            if result.score < RECAPTCHA_SCORE_THRESHOLD:
                logger.warning(
                    f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
                )
                raise CaptchaVerificationError(
                    "Captcha verification failed: suspicious activity detected"
                )

        # Optionally verify the action matches
        if result.action and result.action != expected_action:
            logger.warning(
                f"Captcha action mismatch: {result.action} != {expected_action}"
            )
            raise CaptchaVerificationError(
                "Captcha verification failed: action mismatch"
            )

        logger.debug(
            f"Captcha verification passed: score={result.score}, "
            f"action={result.action}"
        )

    except httpx.HTTPError as e:
        logger.error(f"Captcha API request failed: {e}")
        # In case of API errors, we might want to allow registration
        # to prevent blocking legitimate users. This is a policy decision.
        raise CaptchaVerificationError("Captcha verification service unavailable")
@@ -1,192 +0,0 @@
"""
Utility to validate and block disposable/temporary email addresses.

This module fetches a list of known disposable email domains from a remote source
and caches them for performance. It's used during user registration to prevent
abuse from temporary email services.
"""

import threading
import time
from typing import Set

import httpx

from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
from onyx.utils.logger import setup_logger

logger = setup_logger()


class DisposableEmailValidator:
    """
    Thread-safe singleton validator for disposable email domains.

    Fetches and caches the list of disposable domains, with periodic refresh.
    """

    _instance: "DisposableEmailValidator | None" = None
    _lock = threading.Lock()

    def __new__(cls) -> "DisposableEmailValidator":
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Check if already initialized using a try/except to avoid type issues
        try:
            if self._initialized:
                return
        except AttributeError:
            pass

        self._domains: Set[str] = set()
        self._last_fetch_time: float = 0
        self._fetch_lock = threading.Lock()
        # Cache for 1 hour
        self._cache_duration = 3600
        # Hardcoded fallback list of common disposable domains
        # This ensures we block at least these even if the remote fetch fails
        self._fallback_domains = {
            "trashlify.com",
            "10minutemail.com",
            "guerrillamail.com",
            "mailinator.com",
            "tempmail.com",
            "throwaway.email",
            "yopmail.com",
            "temp-mail.org",
            "getnada.com",
            "maildrop.cc",
        }
        # Set initialized flag last to prevent race conditions
        self._initialized: bool = True

    def _should_refresh(self) -> bool:
        """Check if the cached domains should be refreshed."""
        return (time.time() - self._last_fetch_time) > self._cache_duration

    def _fetch_domains(self) -> Set[str]:
        """
        Fetch disposable email domains from the configured URL.

        Returns:
            Set of domain strings (lowercased)
        """
        if not DISPOSABLE_EMAIL_DOMAINS_URL:
            logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
            return self._fallback_domains.copy()

        try:
            logger.info(
                f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
            )
            with httpx.Client(timeout=10.0) as client:
                response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
                response.raise_for_status()

            domains_list = response.json()

            if not isinstance(domains_list, list):
                logger.error(
                    f"Expected list from disposable domains URL, got {type(domains_list)}"
                )
                return self._fallback_domains.copy()

            # Convert all to lowercase and create set
            domains = {domain.lower().strip() for domain in domains_list if domain}

            # Always include fallback domains
            domains.update(self._fallback_domains)

            logger.info(
                f"Successfully fetched {len(domains)} disposable email domains"
            )
            return domains

        except httpx.HTTPError as e:
            logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
        except Exception as e:
            logger.warning(f"Failed to fetch disposable domains: {e}")

        # On error, return fallback domains
        return self._fallback_domains.copy()

    def get_domains(self) -> Set[str]:
        """
        Get the cached set of disposable email domains.
        Refreshes the cache if needed.

        Returns:
            Set of disposable domain strings (lowercased)
        """
        # Fast path: return cached domains if still fresh
        if self._domains and not self._should_refresh():
            return self._domains.copy()

        # Slow path: need to refresh
        with self._fetch_lock:
            # Double-check after acquiring lock
            if self._domains and not self._should_refresh():
                return self._domains.copy()

            self._domains = self._fetch_domains()
            self._last_fetch_time = time.time()
            return self._domains.copy()

    def is_disposable(self, email: str) -> bool:
        """
        Check if an email address uses a disposable domain.

        Args:
            email: The email address to check

        Returns:
            True if the email domain is disposable, False otherwise
        """
        if not email or "@" not in email:
            return False

        parts = email.split("@")
        if len(parts) != 2 or not parts[0]:  # Must have user@domain with non-empty user
            return False

        domain = parts[1].lower().strip()
        if not domain:  # Domain part must not be empty
            return False

        disposable_domains = self.get_domains()
        return domain in disposable_domains


# Global singleton instance
_validator = DisposableEmailValidator()


def is_disposable_email(email: str) -> bool:
    """
    Check if an email address uses a disposable/temporary domain.

    This is a convenience function that uses the global validator instance.

    Args:
        email: The email address to check

    Returns:
        True if the email uses a disposable domain, False otherwise
    """
    return _validator.is_disposable(email)


def refresh_disposable_domains() -> None:
    """
    Force a refresh of the disposable domains list.

    This can be called manually if you want to update the list
    without waiting for the cache to expire.
    """
    _validator._last_fetch_time = 0
    _validator.get_domains()
@@ -40,8 +40,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
    role: UserRole = UserRole.BASIC
    tenant_id: str | None = None
    # Captcha token for cloud signup protection (optional, only used when captcha is enabled)
    captcha_token: str | None = None


class UserUpdateWithRole(schemas.BaseUserUpdate):

@@ -60,7 +60,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
@@ -249,23 +248,13 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:


def verify_email_domain(email: str) -> None:
    if email.count("@") != 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email is not valid",
        )

    domain = email.split("@")[-1].lower()

    # Check if email uses a disposable/temporary domain
    if is_disposable_email(email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Disposable email addresses are not allowed. Please use a permanent email address.",
        )

    # Check domain whitelist if configured
    if VALID_EMAIL_DOMAINS:
        if email.count("@") != 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email is not valid",
            )
        domain = email.split("@")[-1].lower()
        if domain not in VALID_EMAIL_DOMAINS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
@@ -303,57 +292,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        safe: bool = False,
        request: Optional[Request] = None,
    ) -> User:
        # Verify captcha if enabled (for cloud signup protection)
        from onyx.auth.captcha import CaptchaVerificationError
        from onyx.auth.captcha import is_captcha_enabled
        from onyx.auth.captcha import verify_captcha_token

        if is_captcha_enabled() and request is not None:
            # Get captcha token from request body or headers
            captcha_token = None
            if hasattr(user_create, "captcha_token"):
                captcha_token = getattr(user_create, "captcha_token", None)

            # Also check headers as a fallback
            if not captcha_token:
                captcha_token = request.headers.get("X-Captcha-Token")

            try:
                await verify_captcha_token(
                    captcha_token or "", expected_action="signup"
                )
            except CaptchaVerificationError as e:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail={"reason": str(e)},
                )

        # We verify the password here to make sure it's valid before we proceed
        await self.validate_password(
            user_create.password, cast(schemas.UC, user_create)
        )

        # Check for disposable emails BEFORE provisioning tenant
        # This prevents creating tenants for throwaway email addresses
        try:
            verify_email_domain(user_create.email)
        except HTTPException as e:
            # Log blocked disposable email attempts
            if (
                e.status_code == status.HTTP_400_BAD_REQUEST
                and "Disposable email" in str(e.detail)
            ):
                domain = (
                    user_create.email.split("@")[-1]
                    if "@" in user_create.email
                    else "unknown"
                )
                logger.warning(
                    f"Blocked disposable email registration attempt: {domain}",
                    extra={"email_domain": domain},
                )
            raise

        user_count: int | None = None
        referral_source = (
            request.cookies.get("referral_source", None)
@@ -375,17 +318,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        try:
            async with get_async_session_context_manager(tenant_id) as db_session:
                # Check invite list based on deployment mode
                if MULTI_TENANT:
                    # Multi-tenant: Only require invite for existing tenants
                    # New tenant creation (first user) doesn't require an invite
                    user_count = await get_user_count()
                    if user_count > 0:
                        # Tenant already has users - require invite for new users
                        verify_email_is_invited(user_create.email)
                else:
                    # Single-tenant: Check invite list (skips if SAML/OIDC or no list configured)
                    verify_email_is_invited(user_create.email)
                verify_email_is_invited(user_create.email)
                verify_email_domain(user_create.email)
                if MULTI_TENANT:
                    tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
                        db_session, User, OAuthAccount

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
@@ -516,9 +515,6 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
    """Waits for Vespa to become ready subject to a timeout.
    Raises WorkerShutdown if the timeout is reached."""

    if ENABLE_OPENSEARCH_FOR_ONYX:
        return

    if not wait_for_vespa_with_timeout():
        msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
        logger.error(msg)

@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.docfetching",
        # Ensure the user files indexing worker registers the doc_id migration task
        # TODO(subash): remove this once the doc_id migration is complete
        "onyx.background.celery.tasks.user_file_processing",
    ]
)

@@ -2,12 +2,8 @@ import copy
from datetime import timedelta
from typing import Any

from celery.schedules import crontab

from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -57,6 +53,16 @@ beat_task_templates: list[dict] = [
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "user-file-docid-migration",
        "task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
        "schedule": timedelta(minutes=10),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "queue": OnyxCeleryQueues.USER_FILES_INDEXING,
        },
    },
    {
        "name": "check-for-kg-processing",
        "task": OnyxCeleryTask.CHECK_KG_PROCESSING,
@@ -165,32 +171,13 @@ if ENTERPRISE_EDITION_ENABLED:
        ]
    )

# Add the Auto LLM update task if the config URL is set (has a default)
if AUTO_LLM_CONFIG_URL:
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
    beat_task_templates.append(
        {
            "name": "check-for-auto-llm-update",
            "task": OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE,
            "schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                "expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
            },
        }
    )

# Add scheduled eval task if datasets are configured
if SCHEDULED_EVAL_DATASET_NAMES:
    beat_task_templates.append(
        {
            "name": "scheduled-eval-pipeline",
            "task": OnyxCeleryTask.SCHEDULED_EVAL_TASK,
            # run every Sunday at midnight UTC
            "schedule": crontab(
                hour=0,
                minute=0,
                day_of_week=0,
            ),
            "name": "check-for-llm-model-update",
            "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
            "schedule": timedelta(hours=1),  # Check every hour
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                "expires": BEAT_EXPIRES_DEFAULT,

@@ -72,6 +72,15 @@ def try_creating_docfetching_task(
        # Another indexing attempt is already running
        return None

    # Determine which queue to use based on whether this is a user file
    # TODO: at the moment the indexing pipeline is
    # shared between user files and connectors
    queue = (
        OnyxCeleryQueues.USER_FILES_INDEXING
        if cc_pair.is_user_file
        else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
    )

    # Use higher priority for first-time indexing to ensure new connectors
    # get processed before re-indexing of existing connectors
    has_successful_attempt = cc_pair.last_successful_index_time is not None
@@ -90,7 +99,7 @@ def try_creating_docfetching_task(
            search_settings_id=search_settings.id,
            tenant_id=tenant_id,
        ),
        queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
        queue=queue,
        task_id=custom_task_id,
        priority=priority,
    )

@@ -12,7 +12,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from fastapi import HTTPException
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
@@ -41,11 +40,9 @@ from onyx.background.indexing.checkpointing_utils import (
)
from onyx.background.indexing.index_attempt_utils import cleanup_index_attempts
from onyx.background.indexing.index_attempt_utils import get_old_index_attempts
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import AuthType
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import MilestoneRecordType
@@ -62,9 +59,11 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import (
    fetch_indexable_standard_connector_credential_pair_ids,
)
from onyx.db.connector_credential_pair import (
    fetch_indexable_user_file_connector_credential_pair_ids,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -113,7 +112,6 @@ from onyx.utils.telemetry import RecordType
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import USAGE_LIMITS_ENABLED
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR

@@ -540,7 +538,12 @@ def check_indexing_completion(
    ]:
        # User file connectors must be paused on success
        # NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
        cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
        # TODO: figure out why this doesn't pause connectors during swap
        cc_pair.status = (
            ConnectorCredentialPairStatus.PAUSED
            if cc_pair.is_user_file
            else ConnectorCredentialPairStatus.ACTIVE
        )
        db_session.commit()

        mt_cloud_telemetry(
@@ -806,8 +809,13 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                db_session, active_cc_pairs_only=True
            )
        )
        user_file_cc_pair_ids = (
            fetch_indexable_user_file_connector_credential_pair_ids(
                db_session, search_settings_id=current_search_settings.id
            )
        )

        primary_cc_pair_ids = standard_cc_pair_ids
        primary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids

        # Get CC pairs for secondary search settings
        secondary_cc_pair_ids: list[int] = []
@@ -823,47 +831,30 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                    db_session, active_cc_pairs_only=not include_paused
                )
            )
            user_file_cc_pair_ids = (
                fetch_indexable_user_file_connector_credential_pair_ids(
                    db_session, search_settings_id=secondary_search_settings.id
                )
                or []
            )

            secondary_cc_pair_ids = standard_cc_pair_ids
            secondary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids

    # Flag CC pairs in repeated error state for primary/current search settings
    with get_session_with_current_tenant() as db_session:
        for cc_pair_id in primary_cc_pair_ids:
            lock_beat.reacquire()

            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
            if is_in_repeated_error_state(
                cc_pair_id=cc_pair_id,
            )

            # if already in repeated error state, don't do anything
            # this is important so that we don't keep pausing the connector
            # immediately upon a user un-pausing it to manually re-trigger and
            # recover.
            if (
                cc_pair
                and not cc_pair.in_repeated_error_state
                and is_in_repeated_error_state(
                    cc_pair=cc_pair,
                    search_settings_id=current_search_settings.id,
                    db_session=db_session,
                )
                search_settings_id=current_search_settings.id,
                db_session=db_session,
            ):
                set_cc_pair_repeated_error_state(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,
                    in_repeated_error_state=True,
                )
                # When entering repeated error state, also pause the connector
                # to prevent continued indexing retry attempts burning through embedding credits.
                # NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
                # models. Also, they are more prone to repeated failures -> eventual success.
                if AUTH_TYPE == AuthType.CLOUD:
                    update_connector_credential_pair_from_id(
                        db_session=db_session,
                        cc_pair_id=cc_pair.id,
                        status=ConnectorCredentialPairStatus.PAUSED,
                    )

    # NOTE: At this point, we haven't done heavy checks on whether or not the CC pairs should actually be indexed
    # Heavy check, should_index(), is called in _kickoff_indexing_tasks
@@ -1288,26 +1279,6 @@ def docprocessing_task(
        INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)


def _check_chunk_usage_limit(tenant_id: str) -> None:
    """Check if chunk indexing usage limit has been exceeded.

    Raises UsageLimitExceededError if the limit is exceeded.
    """
    if not USAGE_LIMITS_ENABLED:
        return

    from onyx.db.usage import UsageType
    from onyx.server.usage_limits import check_usage_and_raise

    with get_session_with_current_tenant() as db_session:
        check_usage_and_raise(
            db_session=db_session,
            usage_type=UsageType.CHUNKS_INDEXED,
            tenant_id=tenant_id,
            pending_amount=0,  # Just check current usage
        )


def _docprocessing_task(
    index_attempt_id: int,
    cc_pair_id: int,
@@ -1319,25 +1290,6 @@ def _docprocessing_task(
    if tenant_id:
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    # Check if chunk indexing usage limit has been exceeded before processing
    if USAGE_LIMITS_ENABLED:
        try:
            _check_chunk_usage_limit(tenant_id)
        except HTTPException as e:
            # Log the error and fail the indexing attempt
            task_logger.error(
                f"Chunk indexing usage limit exceeded for tenant {tenant_id}: {e}"
            )
            with get_session_with_current_tenant() as db_session:
                from onyx.db.index_attempt import mark_attempt_failed

                mark_attempt_failed(
                    index_attempt_id=index_attempt_id,
                    db_session=db_session,
                    failure_reason=str(e),
                )
            raise

    task_logger.info(
        f"Processing document batch: "
        f"attempt={index_attempt_id} "
@@ -1482,23 +1434,6 @@ def _docprocessing_task(
        adapter=adapter,
    )

    # Track chunk indexing usage for cloud usage limits
    if USAGE_LIMITS_ENABLED and index_pipeline_result.total_chunks > 0:
        try:
            from onyx.db.usage import increment_usage
            from onyx.db.usage import UsageType

            with get_session_with_current_tenant() as usage_db_session:
                increment_usage(
                    db_session=usage_db_session,
                    usage_type=UsageType.CHUNKS_INDEXED,
                    amount=index_pipeline_result.total_chunks,
                )
                usage_db_session.commit()
        except Exception as e:
            # Log but don't fail indexing if usage tracking fails
            task_logger.warning(f"Failed to track chunk indexing usage: {e}")

    # Update batch completion and document counts atomically using database coordination

    with get_session_with_current_tenant() as db_session, cross_batch_db_lock:

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
@@ -125,9 +126,18 @@ class IndexingCallback(IndexingHeartbeatInterface):


def is_in_repeated_error_state(
    cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session
    cc_pair_id: int, search_settings_id: int, db_session: Session
) -> bool:
    """Checks if the cc pair / search setting combination is in a repeated error state."""
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        raise RuntimeError(
            f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
        )

    # if the connector doesn't have a refresh_freq, a single failed attempt is enough
    number_of_failed_attempts_in_a_row_needed = (
        NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
@@ -136,7 +146,7 @@ def is_in_repeated_error_state(
    )

    most_recent_index_attempts = get_recent_attempts_for_cc_pair(
        cc_pair_id=cc_pair.id,
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        limit=number_of_failed_attempts_in_a_row_needed,
        db_session=db_session,
@@ -170,7 +180,7 @@ def should_index(
        db_session=db_session,
    )
    all_recent_errored = is_in_repeated_error_state(
        cc_pair=cc_pair,
        cc_pair_id=cc_pair.id,
        search_settings_id=search_settings_instance.id,
        db_session=db_session,
    )

@@ -1,15 +1,9 @@
from datetime import datetime
from datetime import timezone
from typing import Any

from celery import shared_task
from celery import Task

from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
from onyx.configs.constants import OnyxCeleryTask
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
@@ -39,109 +33,3 @@ def eval_run_task(
    except Exception:
        logger.error("Failed to run eval task")
        raise


@shared_task(
    name=OnyxCeleryTask.SCHEDULED_EVAL_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT * 5,  # Allow more time for multiple datasets
    bind=True,
    trail=False,
)
def scheduled_eval_task(self: Task, **kwargs: Any) -> None:
    """
    Scheduled task to run evaluations on configured datasets.
    Runs weekly on Sunday at midnight UTC.

    Configure via environment variables (with defaults):
    - SCHEDULED_EVAL_DATASET_NAMES: Comma-separated list of Braintrust dataset names
    - SCHEDULED_EVAL_PERMISSIONS_EMAIL: Email for search permissions (default: roshan@onyx.app)
    - SCHEDULED_EVAL_PROJECT: Braintrust project name
    """
    if not BRAINTRUST_API_KEY:
        logger.error("BRAINTRUST_API_KEY is not configured, cannot run scheduled evals")
        return

    if not SCHEDULED_EVAL_PROJECT:
        logger.error(
            "SCHEDULED_EVAL_PROJECT is not configured, cannot run scheduled evals"
        )
        return

    if not SCHEDULED_EVAL_DATASET_NAMES:
        logger.info("No scheduled eval datasets configured, skipping")
        return

    if not SCHEDULED_EVAL_PERMISSIONS_EMAIL:
        logger.error("SCHEDULED_EVAL_PERMISSIONS_EMAIL not configured")
        return

    project_name = SCHEDULED_EVAL_PROJECT
    dataset_names = SCHEDULED_EVAL_DATASET_NAMES
    permissions_email = SCHEDULED_EVAL_PERMISSIONS_EMAIL

    # Create a timestamp for the scheduled run
    run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    logger.info(
        f"Starting scheduled eval pipeline for project '{project_name}' "
        f"with {len(dataset_names)} dataset(s): {dataset_names}"
    )

    pipeline_start = datetime.now(timezone.utc)
    results: list[dict[str, Any]] = []

    for dataset_name in dataset_names:
        start_time = datetime.now(timezone.utc)
        error_message: str | None = None
        success = False

        # Create informative experiment name for scheduled runs
        experiment_name = f"{dataset_name} - {run_timestamp}"

        try:
            logger.info(
                f"Running scheduled eval for dataset: {dataset_name} "
                f"(project: {project_name})"
            )

            configuration = EvalConfigurationOptions(
                search_permissions_email=permissions_email,
                dataset_name=dataset_name,
                no_send_logs=False,
                braintrust_project=project_name,
                experiment_name=experiment_name,
            )

            result = run_eval(
                configuration=configuration,
                remote_dataset_name=dataset_name,
            )
            success = result.success
            logger.info(f"Completed eval for {dataset_name}: success={success}")

        except Exception as e:
            logger.exception(f"Failed to run scheduled eval for {dataset_name}")
            error_message = str(e)
            success = False

        end_time = datetime.now(timezone.utc)

        results.append(
            {
                "dataset_name": dataset_name,
                "success": success,
                "start_time": start_time,
                "end_time": end_time,
                "error_message": error_message,
            }
        )

    pipeline_end = datetime.now(timezone.utc)
    total_duration = (pipeline_end - pipeline_start).total_seconds()

    passed_count = sum(1 for r in results if r["success"])
    logger.info(
        f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed "
        f"in {total_duration:.1f}s"
    )

@@ -1,57 +1,135 @@
from typing import Any

import requests
from celery import shared_task
from celery import Task

from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration


def _process_model_list_response(model_list_json: Any) -> list[str]:
    # Handle case where response is wrapped in a "data" field
    if isinstance(model_list_json, dict):
        if "data" in model_list_json:
            model_list_json = model_list_json["data"]
        elif "models" in model_list_json:
            model_list_json = model_list_json["models"]
        else:
            raise ValueError(
                "Invalid response from API - expected dict with 'data' or "
                f"'models' field, got {type(model_list_json)}"
            )

    if not isinstance(model_list_json, list):
        raise ValueError(
            f"Invalid response from API - expected list, got {type(model_list_json)}"
        )

    # Handle both string list and object list cases
    model_names: list[str] = []
    for item in model_list_json:
        if isinstance(item, str):
            model_names.append(item)
        elif isinstance(item, dict):
            if "model_name" in item:
                model_names.append(item["model_name"])
            elif "id" in item:
                model_names.append(item["id"])
            else:
                raise ValueError(
                    f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
                )
        else:
            raise ValueError(
                f"Invalid item in model list - expected string or dict, got {type(item)}"
            )

    return model_names


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE,
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    ignore_result=True,
    soft_time_limit=300,  # 5 minute timeout
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
    """Periodic task to fetch LLM model updates from GitHub
    and sync them to providers in Auto mode.

    This task checks the GitHub-hosted config file and updates all
    providers that have is_auto_mode=True.
    """
    if not AUTO_LLM_CONFIG_URL:
        task_logger.debug("AUTO_LLM_CONFIG_URL not configured, skipping")
        return None
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
    if not LLM_MODEL_UPDATE_API_URL:
        raise ValueError("LLM model update API URL not configured")

    # First fetch the models from the API
    try:
        # Import here to avoid circular imports
        from onyx.llm.well_known_providers.auto_update_service import (
            fetch_llm_recommendations_from_github,
        )
        from onyx.llm.well_known_providers.auto_update_service import (
            sync_llm_models_from_github,
        )

        # Fetch config from GitHub
        config = fetch_llm_recommendations_from_github()

        if not config:
            task_logger.warning("Failed to fetch GitHub config")
            return None

        # Sync to database
        with get_session_with_current_tenant() as db_session:
            results = sync_llm_models_from_github(db_session, config)

        if results:
            task_logger.info(f"Auto mode sync results: {results}")
        else:
            task_logger.debug("No model updates applied")
        response = requests.get(LLM_MODEL_UPDATE_API_URL)
        response.raise_for_status()
        available_models = _process_model_list_response(response.json())
        task_logger.info(f"Found available models: {available_models}")

    except Exception:
        task_logger.exception("Error in auto LLM update task")
        raise
        task_logger.exception("Failed to fetch models from API.")
        return None

    # Then update the database with the fetched models
    with get_session_with_current_tenant() as db_session:
        # Get the default LLM provider
        default_provider = (
            db_session.query(LLMProvider)
            .filter(LLMProvider.is_default_provider.is_(True))
            .first()
        )

        if not default_provider:
            task_logger.warning("No default LLM provider found")
            return None

        # log change if any
        old_models = set(
            model_configuration.name
            for model_configuration in default_provider.model_configurations
        )
        new_models = set(available_models)
        added_models = new_models - old_models
        removed_models = old_models - new_models

        if added_models:
            task_logger.info(f"Adding models: {sorted(added_models)}")
        if removed_models:
            task_logger.info(f"Removing models: {sorted(removed_models)}")

        # Update the provider's model list
        # Remove models that are no longer available
        db_session.query(ModelConfiguration).filter(
            ModelConfiguration.llm_provider_id == default_provider.id,
            ModelConfiguration.name.notin_(available_models),
        ).delete(synchronize_session=False)

        # Add new models
        for available_model_name in available_models:
            db_session.merge(
                ModelConfiguration(
                    llm_provider_id=default_provider.id,
                    name=available_model_name,
                    is_visible=False,
                    max_input_tokens=None,
                )
            )

        # if the default model is no longer available, set it to the first model in the list
        if default_provider.default_model_name not in available_models:
            task_logger.info(
                f"Default model {default_provider.default_model_name} not "
                f"available, setting to first model in list: {available_models[0]}"
            )
            default_provider.default_model_name = available_models[0]
        db_session.commit()

        if added_models or removed_models:
            task_logger.info("Updated model list for default provider.")

    return True

@@ -886,7 +886,9 @@ def monitor_celery_queues_helper(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)

    n_user_files_indexing = celery_get_queue_length(
        OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
    )
    n_user_file_processing = celery_get_queue_length(
        OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
    )
@@ -922,6 +924,7 @@ def monitor_celery_queues_helper(
        f"docfetching_prefetched={len(n_docfetching_prefetched)} "
        f"docprocessing={n_docprocessing} "
        f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
        f"user_files_indexing={n_user_files_indexing} "
        f"user_file_processing={n_user_file_processing} "
        f"user_file_project_sync={n_user_file_project_sync} "
        f"user_file_delete={n_user_file_delete} "

@@ -1,5 +1,6 @@
import datetime
import time
from collections.abc import Sequence
from typing import Any
from uuid import UUID

@@ -18,9 +19,11 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -29,13 +32,20 @@ from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import UserFileStatus
from onyx.db.models import FileRecord
from onyx.db.models import SearchDoc
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.shared_utils.utils import (
    replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
@@ -608,3 +618,315 @@ def process_single_user_file_project_sync(
        file_lock.release()

    return None


def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
    # Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
    user_prefix = "USER_FILE_CONNECTOR__"
    file_prefix = "FILE_CONNECTOR__"
    if old_id.startswith(user_prefix):
        remainder = old_id[len(user_prefix) :]
        return file_prefix + remainder
    return old_id


def update_legacy_plaintext_file_records() -> None:
    """Migrate legacy plaintext cache objects from int-based keys to UUID-based
    keys. Copies each S3 object to its expected UUID key and updates DB.

    Examples:
    - Old key: bucket/schema/plaintext_<int>
    - New key: bucket/schema/plaintext_<uuid>
    """

    task_logger.info("update_legacy_plaintext_file_records - Starting")

    with get_session_with_current_tenant() as db_session:
        store = get_default_file_store()

        if not isinstance(store, S3BackedFileStore):
            task_logger.info(
                "update_legacy_plaintext_file_records - Skipping non-S3 store"
            )
            return

        s3_client = store._get_s3_client()
        bucket_name = store._get_bucket_name()

        # Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
        # Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
        plaintext_records: Sequence[FileRecord] = (
            db_session.execute(
                sa.select(FileRecord).where(
                    FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
                    FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
                )
            )
            .scalars()
            .all()
        )

        task_logger.info(
            f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
        )

        normalized = 0
        for fr in plaintext_records:
            try:
                expected_key = store._get_s3_key(fr.file_id)
                if fr.object_key == expected_key:
                    continue

                if fr.bucket_name is None:
                    task_logger.warning(f"id={fr.file_id} - Bucket name is None")
                    continue

                if fr.object_key is None:
                    task_logger.warning(f"id={fr.file_id} - Object key is None")
                    continue

                # Copy old object to new key
                copy_source = f"{fr.bucket_name}/{fr.object_key}"
                s3_client.copy_object(
                    CopySource=copy_source,
                    Bucket=bucket_name,
                    Key=expected_key,
                    MetadataDirective="COPY",
                )

                # Delete old object (best-effort)
                try:
                    s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
                except Exception:
                    pass

                # Update DB record with new key
                fr.object_key = expected_key
                db_session.add(fr)
                normalized += 1
            except Exception as e:
                task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")

        if normalized:
            db_session.commit()
            task_logger.info(
                f"user_file_docid_migration_task normalized {normalized} plaintext objects"
            )


@shared_task(
    name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
    ignore_result=True,
    bind=True,
)
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:

    task_logger.info(
        f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
    )

    redis_client = get_redis_client(tenant_id=tenant_id)
    lock: RedisLock = redis_client.lock(
        OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
        timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
    )

    if not lock.acquire(blocking=False):
        task_logger.info(
            f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
        )
        return False

    updated_count = 0
    try:
        update_legacy_plaintext_file_records()
        # Track lock renewal
        last_lock_time = time.monotonic()
        with get_session_with_current_tenant() as db_session:

            # 20 is the documented default for httpx max_keepalive_connections
            if MANAGED_VESPA:
                httpx_init_vespa_pool(
                    20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
                )
            else:
                httpx_init_vespa_pool(20)

            active_settings = get_active_search_settings(db_session)
            document_index = get_default_document_index(
                search_settings=active_settings.primary,
                secondary_search_settings=active_settings.secondary,
                httpx_client=HttpxPool.get("vespa"),
            )

            retry_index = RetryDocumentIndex(document_index)

            # Select user files with a legacy doc id that have not been migrated
            user_files = (
                db_session.execute(
                    sa.select(UserFile).where(
                        sa.and_(
                            UserFile.document_id.is_not(None),
                            UserFile.document_id_migrated.is_(False),
                        )
                    )
                )
                .scalars()
                .all()
            )

            task_logger.info(
                f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
            )

            # Query all SearchDocs that need updating
            search_docs = (
                db_session.execute(
                    sa.select(SearchDoc).where(
                        SearchDoc.document_id.like("%FILE_CONNECTOR__%")
                    )
                )
                .scalars()
                .all()
            )
            task_logger.info(
                f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
            )

            # Build a map of normalized doc IDs to SearchDocs
            search_doc_map: dict[str, list[SearchDoc]] = {}
            for sd in search_docs:
                doc_id = sd.document_id
                if search_doc_map.get(doc_id) is None:
                    search_doc_map[doc_id] = []
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
# WARNING: In the future this will error; we no longer want
|
||||
# to support changing document ID.
|
||||
# TODO(andrei): Delete soon.
|
||||
retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = chunk_count
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@@ -63,7 +63,7 @@ To ensure the LLM follows certain specific instructions, instructions are added
tool is used, a citation reminder is always added. Otherwise, by default there is no reminder. If the user configures reminders, those are added to the
final message. If a search related tool just ran and the user has reminders, both appear in a single message.

If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent has responded.


## Tool Calls
@@ -145,83 +145,9 @@ attention despite having global access.
In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
rate even just moving the same statement a few sentences.

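As a rough illustration of this placement effect, here is a minimal sketch of section-scoped prompt assembly. The helper name and section headers below are illustrative assumptions, not the actual Onyx prompt builder:

```python
# Minimal sketch: keep the "feel free to call more tools" instruction inside the
# Tools section of the system prompt rather than near the top. The exact wording
# and section names are assumptions for illustration only.
FOLLOW_UP_TOOL_INSTRUCTION = (
    "If you try to use some tools and find that you need more information or "
    "need to call additional tools, you are encouraged to do this."
)


def build_sectioned_system_prompt(base_instructions: str, tool_descriptions: list[str]) -> str:
    tool_section_lines = ["## Tools", *tool_descriptions, FOLLOW_UP_TOOL_INSTRUCTION]
    # Placing FOLLOW_UP_TOOL_INSTRUCTION up in base_instructions instead corresponds
    # to the ~30% follow-rate case described above; co-locating it with the tools is the ~90% case.
    return base_instructions + "\n\n" + "\n".join(tool_section_lines)


if __name__ == "__main__":
    print(build_sectioned_system_prompt("You are a helpful assistant.", ["search: query the index"]))
```
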
## Other related pointers
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.

---

# Overview of LLM flow architecture

**Concepts:**
Turn: User sends a message and AI does some set of things and responds
Step/Cycle: 1 single LLM inference given some context and some tools

## 1. Top Level (process_message function):
This function can be thought of as the set-up and validation layer. It ensures that the database is in a valid state, reads the
messages in the session and sets up all the necessary items to run the chat loop and state containers. The major things it does
are (a rough sketch follows the list):
- Validates the request
- Builds the chat history for the session
- Fetches any additional context such as files and images
- Prepares all of the tools for the LLM
- Creates the state container objects for use in the loop

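One of these steps, rebuilding the linear chat history from parent-pointer messages, can be sketched in isolation. The `SimpleMessage` type below is a toy stand-in, not the real ChatMessage model:

```python
# Toy sketch of "Builds the chat history for the session": walk the parent
# pointers from the root message to reconstruct the mainline in order.
from dataclasses import dataclass


@dataclass
class SimpleMessage:
    id: int
    parent_id: int | None
    text: str


def build_mainline_history(messages: list[SimpleMessage]) -> list[SimpleMessage]:
    by_parent = {m.parent_id: m for m in messages}  # assumes a single linear chain
    chain: list[SimpleMessage] = []
    current = by_parent.get(None)  # the root has no parent
    while current is not None:
        chain.append(current)
        current = by_parent.get(current.id)
    return chain


if __name__ == "__main__":
    msgs = [SimpleMessage(2, 1, "assistant reply"), SimpleMessage(1, None, "user question")]
    print([m.text for m in build_mainline_history(msgs)])
```
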
### Wrapper (run_chat_loop_with_state_containers function):
This wrapper is used to run the LLM flow in a background thread and monitor the emitter for stop signals. This means the top
level is as isolated from the LLM flow as possible and can continue to yield packets as soon as they are available from the lower
levels. This also means that if the lower levels fail, the top level will still guarantee a reasonable response to the user.
All of the saving and database operations are abstracted away from the lower levels.

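A self-contained sketch of the wrapper pattern, with a plain `queue.Queue` standing in for the emitter bus and a toy producer standing in for the LLM loop (the names and shapes are illustrative, not the real signatures):

```python
# Toy version of the wrapper: run the "LLM loop" in a background thread, drain
# its queue with a short timeout, and keep the top level isolated from it.
import queue
import threading


def toy_llm_loop(bus: "queue.Queue[object]") -> None:
    for token in ["Hello", " ", "world"]:
        bus.put(token)
    bus.put(StopIteration)  # sentinel marking the end of the turn


def run_with_wrapper() -> list[object]:
    bus: "queue.Queue[object]" = queue.Queue()
    thread = threading.Thread(target=toy_llm_loop, args=(bus,), daemon=True)
    thread.start()
    received: list[object] = []
    while True:
        try:
            pkt = bus.get(timeout=0.3)  # short poll so stop signals could be checked here
        except queue.Empty:
            continue
        if pkt is StopIteration:
            break
        received.append(pkt)
    thread.join()
    return received


if __name__ == "__main__":
    print(run_with_wrapper())
```
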
### Emitter
The emitter is designed to be an object queue so that lower levels do not need to yield objects all the way back to the top.
This way the functions can be better designed (not everything as a generator) and more easily tested. The wrapper around the
LLM flow (run_chat_loop_with_state_containers) is used to monitor the emitter and handle packets as soon as they are available
from the lower levels. Both the emitter and the state container are mutating state objects and only used to accumulate state.
There should be no logic dependent on the states of these objects, especially in the lower levels. The emitter should only take
packets and should not be used for other things.

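The core idea is that a lower-level function receives an emitter and pushes packets onto it instead of being written as a generator. A minimal stand-in (the `Emitter` class here is illustrative, not the real one):

```python
# Illustrative emitter: a thin wrapper around a queue that lower levels push to.
import queue
from typing import Any


class Emitter:
    def __init__(self) -> None:
        self.bus: "queue.Queue[Any]" = queue.Queue()

    def emit(self, packet: Any) -> None:
        self.bus.put(packet)


def lower_level_step(emitter: Emitter) -> str:
    # Not a generator: it can return a normal value while still streaming packets upward.
    emitter.emit({"type": "answer_delta", "text": "partial answer"})
    return "final result used by the caller"


if __name__ == "__main__":
    em = Emitter()
    result = lower_level_step(em)
    print(em.bus.get_nowait(), result)
```
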
### State Container
The state container is used to accumulate state during the LLM flow. Similar to the emitter, it should not be used for logic,
only for accumulating state. It is used to gather all of the necessary information for saving the chat turn into the database.
So it will accumulate answer tokens, reasoning tokens, tool calls, citation info, etc. This is used at the end of the flow once
the lower level is completed whether on its own or stopped by the user. At that point, all of the state is read and stored into
the database. The state container can be added to by any of the underlying layers, this is fine.

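In spirit it is just an accumulator object that every layer may append to and that is read once at save time. A stripped-down sketch, with illustrative field names rather than the real ChatStateContainer:

```python
# Stripped-down state container: write-only during the turn, read once when saving.
from dataclasses import dataclass, field


@dataclass
class ToyChatState:
    answer_tokens: list[str] = field(default_factory=list)
    reasoning_tokens: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)

    def add_answer_token(self, token: str) -> None:
        self.answer_tokens.append(token)


def save_turn(state: ToyChatState) -> dict:
    # Only at the very end is the accumulated state read and persisted.
    return {"answer": "".join(state.answer_tokens), "num_tool_calls": len(state.tool_calls)}


if __name__ == "__main__":
    state = ToyChatState()
    for tok in ["Hi", " there"]:
        state.add_answer_token(tok)
    print(save_turn(state))
```
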
### Stopping Generation
A stop signal is checked every 300ms by the wrapper around the LLM flow. The signal itself
is stored in Redis and is set by the user calling the stop endpoint. The wrapper ensures that no matter what the lower level is
doing at the time, the thread can be killed by the top level. It does not require a cooperative cancellation from the lower level
and in fact the lower level does not know about the stop signal at all.

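A sketch of the polling side, with a `threading.Event` standing in for the Redis key so the snippet stays self-contained (the real signal is a Redis flag set by the stop endpoint; the check function name here is an assumption):

```python
# Stand-in for the Redis-backed stop signal: the wrapper polls it every ~300ms.
import threading
import time

stop_flag = threading.Event()  # in production this would be a per-session Redis key


def is_connected() -> bool:
    # Mirrors the real check: "still connected" means no stop signal has been set.
    return not stop_flag.is_set()


def wrapper_poll_loop() -> str:
    last_check = time.monotonic()
    while True:
        time.sleep(0.05)  # pretend we are draining packets here
        if time.monotonic() - last_check >= 0.3:
            if not is_connected():
                return "user_cancelled"
            last_check = time.monotonic()


if __name__ == "__main__":
    threading.Timer(0.5, stop_flag.set).start()
    print(wrapper_poll_loop())
```
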
## 2. LLM Loop (run_llm_loop function)
This function handles the logic of the Turn. It's essentially a while loop where context is added and modified (according to what
is outlined in the first half of this doc). Its main functionality is (sketched below):
- Translate and truncate the context for the LLM inference
- Add context modifiers like reminders, updates to the system prompts, etc.
- Run tool calls and gather results
- Build some of the objects stored in the state container.

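A toy version of the loop's control flow. The stubbed `run_step` below stands in for a real LLM inference and the cycle cap is arbitrary; none of these names are the real Onyx functions:

```python
# Toy turn loop: each cycle runs one "LLM step"; if the step asks for a tool we
# execute it and feed the result back, otherwise the answer ends the turn.
MAX_LLM_CYCLES = 4


def run_step(context: list[str], tools_allowed: bool) -> dict:
    # Stub standing in for a real model call.
    if tools_allowed and not any(msg.startswith("tool:") for msg in context):
        return {"tool_call": "search", "arguments": {"query": context[-1]}}
    return {"answer": f"answer based on {len(context)} context messages"}


def run_toy_llm_loop(user_message: str) -> str:
    context = [user_message]
    for cycle in range(MAX_LLM_CYCLES):
        last_cycle = cycle == MAX_LLM_CYCLES - 1
        step = run_step(context, tools_allowed=not last_cycle)  # last cycle must answer
        if "tool_call" in step:
            context.append(f"tool:{step['tool_call']} -> fake results")  # run tool, add result
            continue
        return step["answer"]
    raise RuntimeError("loop ended without an answer")


if __name__ == "__main__":
    print(run_toy_llm_loop("What is Onyx?"))
```
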
## 3. LLM Step (run_llm_step function)
This function is a single inference of the LLM. It's a wrapper around the LLM stream function which handles packet translations
so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they
do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different
tool calls and returns that to the LLM Loop to execute.

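The section-tracking part can be shown with a toy stream of deltas; the delta shape below is a simplification of what a streaming LLM API returns, not the exact Onyx model classes:

```python
# Toy step: consume streamed deltas, emit each piece immediately, and keep
# per-section accumulators so reasoning / answer / tool calls stay separated.
def run_toy_llm_step(deltas: list[dict]) -> dict:
    sections = {"reasoning": "", "answer": "", "tool_call_args": ""}
    for delta in deltas:
        for key in sections:
            if key in delta:
                sections[key] += delta[key]
                print(f"emit {key} token: {delta[key]!r}")  # would be emitter.emit(...) in real code
    return sections


if __name__ == "__main__":
    stream = [
        {"reasoning": "The user wants a fact. "},
        {"tool_call_args": '{"query": '},
        {"tool_call_args": '"onyx"}'},
        {"answer": "Here is what I found."},
    ]
    print(run_toy_llm_step(stream))
```
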
## Things to know
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.

- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend. A toy illustration
of the three shapes follows.

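Roughly, the same user message passes through three progressively simpler shapes. The classes below are toy stand-ins for the real models, not their actual fields:

```python
# Toy stand-ins for the three message representations: ORM row -> rich in-code
# model -> minimal dict handed to the LLM interface.
from dataclasses import dataclass


@dataclass
class ToyChatMessageRow:          # stands in for the ChatMessage ORM model
    id: int
    message: str
    token_count: int
    message_type: str


@dataclass
class ToyChatMessageSimple:       # stands in for ChatMessageSimple
    message: str
    token_count: int
    message_type: str


def to_simple(row: ToyChatMessageRow) -> ToyChatMessageSimple:
    return ToyChatMessageSimple(row.message, row.token_count, row.message_type)


def to_llm_input(msg: ToyChatMessageSimple) -> dict:
    # The LLM layer only needs role + content.
    return {"role": msg.message_type, "content": msg.message}


if __name__ == "__main__":
    row = ToyChatMessageRow(id=7, message="hello", token_count=1, message_type="user")
    print(to_llm_input(to_simple(row)))
```
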
@@ -1,5 +1,4 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
@@ -39,7 +38,6 @@ class ChatStateContainer:
|
||||
self.citation_to_doc: CitationMapping = {}
|
||||
# True if this turn is a clarification question (deep research flow)
|
||||
self.is_clarification: bool = False
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
|
||||
"""Add a tool call to the accumulated state."""
|
||||
@@ -146,9 +144,6 @@ def run_chat_loop_with_state_containers(
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
@@ -157,40 +152,18 @@ def run_chat_loop_with_state_containers(
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
# Stop signal detected, kill the thread
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
if pkt.obj == OverallStop(type="stop"):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.llm import fetch_existing_doc_sets
|
||||
@@ -38,7 +37,6 @@ from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import check_project_ownership
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -53,7 +51,6 @@ from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
|
||||
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
@@ -64,45 +61,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
chat_session_request: ChatSessionCreationRequest,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> ChatSession:
|
||||
"""Create a chat session from a ChatSessionCreationRequest.
|
||||
|
||||
Includes project ownership validation when project_id is provided.
|
||||
|
||||
Args:
|
||||
chat_session_request: The request containing persona_id, description, and project_id
|
||||
user_id: The ID of the user creating the session (can be None for anonymous)
|
||||
db_session: The database session
|
||||
|
||||
Returns:
|
||||
The newly created ChatSession
|
||||
|
||||
Raises:
|
||||
ValueError: If user lacks access to the specified project
|
||||
Exception: If the persona is invalid
|
||||
"""
|
||||
project_id = chat_session_request.project_id
|
||||
if project_id:
|
||||
if not check_project_ownership(project_id, user_id, db_session):
|
||||
raise ValueError("User does not have access to project")
|
||||
|
||||
return create_chat_session(
|
||||
db_session=db_session,
|
||||
description=chat_session_request.description or "",
|
||||
user_id=user_id,
|
||||
persona_id=chat_session_request.persona_id,
|
||||
project_id=chat_session_request.project_id,
|
||||
)
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
@@ -116,7 +77,6 @@ def prepare_chat_message_request(
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
@@ -143,7 +103,6 @@ def prepare_chat_message_request(
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -205,15 +164,13 @@ def create_chat_history_chain(
|
||||
)
|
||||
|
||||
if not all_chat_messages:
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
raise RuntimeError("No messages in Chat Session")
|
||||
|
||||
root_message = all_chat_messages[0]
|
||||
if root_message.parent_message is not None:
|
||||
raise RuntimeError(
|
||||
"Invalid root message, unable to fetch valid chat message sequence"
|
||||
)
|
||||
else:
|
||||
root_message = all_chat_messages[0]
|
||||
if root_message.parent_message is not None:
|
||||
raise RuntimeError(
|
||||
"Invalid root message, unable to fetch valid chat message sequence"
|
||||
)
|
||||
|
||||
current_message: ChatMessage | None = root_message
|
||||
previous_message: ChatMessage | None = None
|
||||
@@ -244,6 +201,9 @@ def create_chat_history_chain(
|
||||
|
||||
previous_message = current_message
|
||||
|
||||
if not mainline_messages:
|
||||
raise RuntimeError("Could not trace chat message history")
|
||||
|
||||
return mainline_messages
|
||||
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ def _build_project_file_citation_mapping(
|
||||
|
||||
|
||||
def construct_message_history(
|
||||
system_prompt: ChatMessageSimple | None,
|
||||
system_prompt: ChatMessageSimple,
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
@@ -114,7 +114,7 @@ def construct_message_history(
|
||||
)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
history_token_budget -= system_prompt.token_count if system_prompt else 0
|
||||
history_token_budget -= system_prompt.token_count
|
||||
history_token_budget -= (
|
||||
custom_agent_prompt.token_count if custom_agent_prompt else 0
|
||||
)
|
||||
@@ -125,12 +125,9 @@ def construct_message_history(
|
||||
if history_token_budget < 0:
|
||||
raise ValueError("Not enough tokens available to construct message history")
|
||||
|
||||
if system_prompt:
|
||||
system_prompt.should_cache = True
|
||||
|
||||
# If no history, build minimal context
|
||||
if not simple_chat_history:
|
||||
result = [system_prompt] if system_prompt else []
|
||||
result = [system_prompt]
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
if project_files and project_files.project_file_texts:
|
||||
@@ -202,7 +199,6 @@ def construct_message_history(
|
||||
|
||||
for msg in reversed(history_before_last_user):
|
||||
if current_token_count + msg.token_count <= remaining_budget:
|
||||
msg.should_cache = True
|
||||
truncated_history_before.insert(0, msg)
|
||||
current_token_count += msg.token_count
|
||||
else:
|
||||
@@ -222,7 +218,7 @@ def construct_message_history(
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
result = [system_prompt]
|
||||
|
||||
# 1. Add truncated history before last user message
|
||||
result.extend(truncated_history_before)
|
||||
@@ -346,13 +342,8 @@ def run_llm_loop(
|
||||
has_called_search_tool: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
system_prompt = None
|
||||
custom_agent_prompt_msg = None
|
||||
|
||||
reasoning_cycles = 0
|
||||
for llm_cycle_count in range(MAX_LLM_CYCLES):
|
||||
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
|
||||
if forced_tool_id:
|
||||
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
|
||||
final_tools = [tool for tool in tools if tool.id == forced_tool_id]
|
||||
@@ -360,7 +351,7 @@ def run_llm_loop(
|
||||
raise ValueError(f"Tool {forced_tool_id} not found in tools")
|
||||
tool_choice = ToolChoiceOptions.REQUIRED
|
||||
forced_tool_id = None
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
elif llm_cycle_count == MAX_LLM_CYCLES - 1 or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
@@ -379,47 +370,35 @@ def run_llm_loop(
|
||||
)
|
||||
custom_agent_prompt_msg = None
|
||||
else:
|
||||
# If it's an empty string, we assume the user does not want to include it as an empty System message
|
||||
if default_base_system_prompt:
|
||||
open_ai_formatting_enabled = model_needs_formatting_reenabled(
|
||||
llm.config.model_name
|
||||
)
|
||||
# System message and custom agent message are both included.
|
||||
open_ai_formatting_enabled = model_needs_formatting_reenabled(
|
||||
llm.config.model_name
|
||||
)
|
||||
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
memories=memories,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
open_ai_formatting_enabled=open_ai_formatting_enabled,
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=get_default_base_system_prompt(db_session),
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
memories=memories,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
open_ai_formatting_enabled=open_ai_formatting_enabled,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=system_prompt_str,
|
||||
token_count=token_counter(system_prompt_str),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
custom_agent_prompt_msg = (
|
||||
ChatMessageSimple(
|
||||
message=custom_agent_prompt,
|
||||
token_count=token_counter(custom_agent_prompt),
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=system_prompt_str,
|
||||
token_count=token_counter(system_prompt_str),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
custom_agent_prompt_msg = (
|
||||
ChatMessageSimple(
|
||||
message=custom_agent_prompt,
|
||||
token_count=token_counter(custom_agent_prompt),
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
if custom_agent_prompt
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# If there is a custom agent prompt, it replaces the system prompt when the default system prompt is empty
|
||||
system_prompt = (
|
||||
ChatMessageSimple(
|
||||
message=custom_agent_prompt,
|
||||
token_count=token_counter(custom_agent_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
if custom_agent_prompt
|
||||
else None
|
||||
)
|
||||
custom_agent_prompt_msg = None
|
||||
if custom_agent_prompt
|
||||
else None
|
||||
)
|
||||
|
||||
reminder_message_text: str | None
|
||||
if ran_image_gen:
|
||||
@@ -427,7 +406,7 @@ def run_llm_loop(
|
||||
# This is to prevent it generating things like:
|
||||
# [Cute Cat](attachment://a_cute_cat_sitting_playfully.png)
|
||||
reminder_message_text = IMAGE_GEN_REMINDER
|
||||
elif just_ran_web_search and not out_of_cycles:
|
||||
elif just_ran_web_search:
|
||||
reminder_message_text = OPEN_URL_REMINDER
|
||||
else:
|
||||
# This is the default case, the LLM at this point may answer so it is important
|
||||
@@ -438,7 +417,6 @@ def run_llm_loop(
|
||||
),
|
||||
include_citation_reminder=should_cite_documents
|
||||
or always_cite_documents,
|
||||
is_last_cycle=out_of_cycles,
|
||||
)
|
||||
|
||||
reminder_msg = (
|
||||
@@ -513,7 +491,6 @@ def run_llm_loop(
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
@@ -18,7 +17,6 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import Delta
|
||||
@@ -33,7 +31,6 @@ from onyx.llm.models import TextContentPart
|
||||
from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -49,6 +46,7 @@ from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -279,7 +277,6 @@ def _extract_tool_call_kickoffs(
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
) -> LanguageModelInput:
|
||||
"""Convert a list of ChatMessageSimple to LanguageModelInput format.
|
||||
|
||||
@@ -287,23 +284,8 @@ def translate_history_to_llm_format(
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
last_cacheable_msg_idx = -1
|
||||
all_previous_msgs_cacheable = True
|
||||
|
||||
for idx, msg in enumerate(history):
|
||||
# if the message is being added to the history
|
||||
if msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.ASSISTANT,
|
||||
MessageType.TOOL_CALL_RESPONSE,
|
||||
]:
|
||||
all_previous_msgs_cacheable = (
|
||||
all_previous_msgs_cacheable and msg.should_cache
|
||||
)
|
||||
if all_previous_msgs_cacheable:
|
||||
last_cacheable_msg_idx = idx
|
||||
|
||||
for msg in history:
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
system_msg = SystemMessage(
|
||||
role="system",
|
||||
@@ -413,7 +395,7 @@ def translate_history_to_llm_format(
|
||||
assistant_msg_with_tool = AssistantMessage(
|
||||
role="assistant",
|
||||
content=None, # The tool call is parsed, doesn't need to be duplicated in the content
|
||||
tool_calls=tool_calls or None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
messages.append(assistant_msg_with_tool)
|
||||
|
||||
@@ -435,18 +417,6 @@ def translate_history_to_llm_format(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
# prompt caching: rely on should_cache in ChatMessageSimple to
|
||||
# pick the split point for the cacheable prefix and suffix
|
||||
if last_cacheable_msg_idx != -1:
|
||||
processed_messages, _ = process_with_prompt_cache(
|
||||
llm_config=llm_config,
|
||||
cacheable_prefix=messages[: last_cacheable_msg_idx + 1],
|
||||
suffix=messages[last_cacheable_msg_idx + 1 :],
|
||||
continuation=False,
|
||||
)
|
||||
assert isinstance(processed_messages, list) # for mypy
|
||||
messages = processed_messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@@ -459,10 +429,6 @@ def _increment_turns(
|
||||
return turn_index, sub_turn_index + 1
|
||||
|
||||
|
||||
def _delta_has_action(delta: Delta) -> bool:
|
||||
return bool(delta.content or delta.reasoning_content or delta.tool_calls)
|
||||
|
||||
|
||||
def run_llm_step_pkt_generator(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
@@ -533,7 +499,7 @@ def run_llm_step_pkt_generator(
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history, llm.config)
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
has_reasoned = 0
|
||||
|
||||
# Uncomment the line below to log the entire message history to the console
|
||||
@@ -560,8 +526,6 @@ def run_llm_step_pkt_generator(
|
||||
span_generation.span_data.input = cast(
|
||||
Sequence[Mapping[str, Any]], llm_msg_history
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -579,13 +543,7 @@ def run_llm_step_pkt_generator(
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
delta = packet.choice.delta
|
||||
if not first_action_recorded and _delta_has_action(delta):
|
||||
span_generation.span_data.time_to_first_action_seconds = (
|
||||
time.monotonic() - stream_start_time
|
||||
)
|
||||
first_action_recorded = True
|
||||
|
||||
if custom_token_processor:
|
||||
# The custom token processor can modify the deltas for specific custom logic
|
||||
@@ -744,15 +702,6 @@ def run_llm_step_pkt_generator(
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
if (
|
||||
not first_action_recorded
|
||||
and flush_delta is not None
|
||||
and _delta_has_action(flush_delta)
|
||||
):
|
||||
span_generation.span_data.time_to_first_action_seconds = (
|
||||
time.monotonic() - stream_start_time
|
||||
)
|
||||
first_action_recorded = True
|
||||
if flush_delta and flush_delta.tool_calls:
|
||||
for tool_call_delta in flush_delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -17,9 +16,7 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
@@ -135,13 +132,6 @@ class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
search_usage: SearchToolUsage
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
@@ -181,10 +171,6 @@ AnswerQuestionPossibleReturn = (
|
||||
)
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
@@ -195,14 +181,12 @@ class LLMMetricsContainer(BaseModel):
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| StreamStopInfo
|
||||
| MessageResponseIDInfo
|
||||
| StreamingError
|
||||
| UserKnowledgeFilePacket
|
||||
| CreateChatSessionID
|
||||
)
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
@@ -220,37 +204,6 @@ class ChatBasicResponse(BaseModel):
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Tool call with full details for non-streaming response."""
|
||||
|
||||
tool_name: str
|
||||
tool_arguments: dict[str, Any]
|
||||
tool_result: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# Reasoning that led to the tool call
|
||||
pre_reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatFullResponse(BaseModel):
|
||||
"""Complete non-streaming response with all available data."""
|
||||
|
||||
# Core response fields
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
pre_answer_reasoning: str | None = None
|
||||
tool_calls: list[ToolCallResponse] = []
|
||||
|
||||
# Documents & citations
|
||||
top_documents: list[SearchDoc]
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
# Metadata
|
||||
message_id: int
|
||||
chat_session_id: UUID | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class ChatLoadedFile(InMemoryChatFile):
|
||||
content_text: str | None
|
||||
token_count: int
|
||||
@@ -264,12 +217,6 @@ class ChatMessageSimple(BaseModel):
|
||||
image_files: list[ChatLoadedFile] | None = None
|
||||
# Only for TOOL_CALL_RESPONSE type messages
|
||||
tool_call_id: str | None = None
|
||||
# The last message for which this is true
|
||||
# AND is true for all previous messages
|
||||
# (counting from the start of the history)
|
||||
# represents the end of the cacheable prefix
|
||||
# used for prompt caching
|
||||
should_cache: bool = False
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
@@ -287,8 +234,6 @@ class ExtractedProjectFiles(BaseModel):
|
||||
total_token_count: int
|
||||
# Metadata for project files to enable citations
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
IMPORTANT: familiarize yourself with the design concepts prior to contributing to this file.
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -13,7 +10,6 @@ from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import create_chat_session_from_request
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
from onyx.chat.chat_utils import is_last_assistant_message_clarification
|
||||
from onyx.chat.chat_utils import load_all_chat_files
|
||||
@@ -21,36 +17,37 @@ from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import ToolCallResponse
|
||||
from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
@@ -60,24 +57,23 @@ from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolUsage
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -130,7 +126,6 @@ def _extract_project_file_texts_and_images(
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=None,
|
||||
)
|
||||
|
||||
max_actual_tokens = (
|
||||
@@ -216,120 +211,159 @@ def _extract_project_file_texts_and_images(
|
||||
project_as_filter=project_as_filter,
|
||||
total_token_count=total_token_count,
|
||||
project_file_metadata=project_file_metadata,
|
||||
project_uncapped_token_count=project_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
loaded_project_files: bool,
|
||||
project_has_files: bool,
|
||||
forced_tool_id: int | None,
|
||||
has_project_file_texts: bool,
|
||||
forced_tool_ids: list[int] | None,
|
||||
search_tool_id: int | None,
|
||||
) -> ProjectSearchConfig:
|
||||
) -> SearchToolUsage:
|
||||
"""Determine search tool availability based on project context.
|
||||
|
||||
Search is disabled when ALL of the following are true:
|
||||
- User is in a project
|
||||
- Using the default persona (not a custom agent)
|
||||
- Project files are already loaded in context
|
||||
Args:
|
||||
project_id: The project ID if the user is in a project
|
||||
persona_id: The persona ID to check if it's the default persona
|
||||
has_project_file_texts: Whether project files are loaded in context
|
||||
forced_tool_ids: List of forced tool IDs (may be mutated to remove search tool)
|
||||
search_tool_id: The search tool ID to check against
|
||||
|
||||
When search is disabled and the user tried to force the search tool,
|
||||
that forcing is also disabled.
|
||||
|
||||
Returns AUTO (follow persona config) in all other cases.
|
||||
Returns:
|
||||
SearchToolUsage setting indicating how search should be used
|
||||
"""
|
||||
# Not in a project, this should have no impact on search tool availability
|
||||
if not project_id:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
# If no files are uploaded, search should remain enabled
|
||||
search_usage_forcing_setting = SearchToolUsage.AUTO
|
||||
if project_id:
|
||||
if bool(persona_id is DEFAULT_PERSONA_ID and has_project_file_texts):
|
||||
search_usage_forcing_setting = SearchToolUsage.DISABLED
|
||||
# Remove search tool from forced_tool_ids if it's present
|
||||
if forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
forced_tool_ids[:] = [
|
||||
tool_id for tool_id in forced_tool_ids if tool_id != search_tool_id
|
||||
]
|
||||
elif forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
search_usage_forcing_setting = SearchToolUsage.ENABLED
|
||||
return search_usage_forcing_setting
|
||||
|
||||
# Custom persona in project - let persona config decide
|
||||
# Even if there are no files in the project, it's still guided by the persona config.
|
||||
if persona_id != DEFAULT_PERSONA_ID:
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
|
||||
)
|
||||
|
||||
# If in a project with the default persona and the files have been already loaded into the context or
|
||||
# there are no files in the project, disable search as there is nothing to search for.
|
||||
if loaded_project_files or not project_has_files:
|
||||
user_forced_search = (
|
||||
forced_tool_id is not None
|
||||
and search_tool_id is not None
|
||||
and forced_tool_id == search_tool_id
|
||||
)
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.DISABLED,
|
||||
disable_forced_tool=user_forced_search,
|
||||
)
|
||||
|
||||
# Default persona in a project with files, but also the files have not been loaded into the context already.
|
||||
return ProjectSearchConfig(
|
||||
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
|
||||
def _initialize_chat_session(
|
||||
message_text: str,
|
||||
files: list[FileDescriptor],
|
||||
token_counter: Callable[[str], int],
|
||||
parent_id: int | None,
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
use_existing_user_message: bool = False,
|
||||
) -> ChatMessage:
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if parent_id is None:
|
||||
parent_message = root_message
|
||||
else:
|
||||
parent_message = get_chat_message(
|
||||
chat_message_id=parent_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
# For seeding, the parent message points to the message that is supposed to be the last
|
||||
# user message.
|
||||
if use_existing_user_message:
|
||||
if parent_message.parent_message is None:
|
||||
raise RuntimeError("No parent message found for seeding")
|
||||
if parent_message.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"Parent message is not a user message, needed for seeded flow."
|
||||
)
|
||||
message_text = parent_message.message
|
||||
token_count = parent_message.token_count
|
||||
parent_message = parent_message.parent_message
|
||||
else:
|
||||
token_count = token_counter(message_text)
|
||||
|
||||
# Flushed for ID but not committed yet
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=parent_message,
|
||||
message=message_text,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.USER,
|
||||
files=files,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
return user_message
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
# For flow with search, don't include as many chunks as possible since we need to leave space
|
||||
# for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages. Both of the below are used for Slack
|
||||
# messages.
|
||||
# NOTE: is not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
# Slack context for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
# Optional external state container for non-streaming access to accumulated state
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
tenant_id = get_current_tenant_id()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
|
||||
llm: LLM | None = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
user.email
|
||||
if user is not None and getattr(user, "email", None)
|
||||
else (str(user_id) if user_id else "anonymous_user")
|
||||
)
|
||||
try:
|
||||
if not new_msg_req.chat_session_id:
|
||||
if not new_msg_req.chat_session_info:
|
||||
raise RuntimeError(
|
||||
"Must specify a chat session id or chat session info"
|
||||
)
|
||||
chat_session = create_chat_session_from_request(
|
||||
chat_session_request=new_msg_req.chat_session_info,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
user.email
|
||||
if user is not None and getattr(user, "email", None)
|
||||
else (str(user_id) if user_id else "anonymous_user")
|
||||
)
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
persona = chat_session.persona
|
||||
|
||||
message_text = new_msg_req.message
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session.id)
|
||||
user_id=llm_user_identifier, session_id=str(chat_session_id)
|
||||
)
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
new_msg_req.alternate_assistant_id
|
||||
user_selected_filters = retrieval_options.filters if retrieval_options else None
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# Milestone tracking, most devs using the API don't need to understand this
|
||||
@@ -339,6 +373,11 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
raise RuntimeError(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -348,14 +387,6 @@ def handle_stream_message_objects(
|
||||
)
|
||||
token_counter = get_llm_token_counter(llm)
|
||||
|
||||
# Check LLM cost limits before using the LLM (only for Onyx-managed keys)
|
||||
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
|
||||
# Verify that the user specified files actually belong to the user
|
||||
verify_user_files(
|
||||
user_files=new_msg_req.file_descriptors,
|
||||
@@ -364,58 +395,35 @@ def handle_stream_message_objects(
|
||||
project_id=chat_session.project_id,
|
||||
)
|
||||
|
||||
# Makes sure that the chat session has the right message nodes
|
||||
# and that the latest user message is created (not yet committed)
|
||||
user_message = _initialize_chat_session(
|
||||
message_text=message_text,
|
||||
files=new_msg_req.file_descriptors,
|
||||
token_counter=token_counter,
|
||||
parent_id=parent_id,
|
||||
user_id=user_id,
|
||||
chat_session_id=chat_session_id,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
)
|
||||
|
||||
# re-create linear history of messages
|
||||
chat_history = create_chat_history_chain(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
# Determine the parent message based on the request:
# - -1: auto-place after latest message in chain
# - None: regeneration from root (first message)
# - positive int: place after that specific parent message
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
last_chat_message = chat_history[-1]
|
||||
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
chat_history = []
|
||||
else:
|
||||
# Specific parent message ID provided, find parent in chat_history
|
||||
parent_message = None
|
||||
for i in range(len(chat_history) - 1, -1, -1):
|
||||
if chat_history[i].id == new_msg_req.parent_message_id:
|
||||
parent_message = chat_history[i]
|
||||
# Truncate history to only include messages up to and including parent
|
||||
chat_history = chat_history[: i + 1]
|
||||
break
|
||||
|
||||
if parent_message is None:
|
||||
raise ValueError(
|
||||
"The new message sent is not on the latest mainline of messages"
|
||||
if last_chat_message.id != user_message.id:
|
||||
db_session.rollback()
|
||||
raise RuntimeError(
|
||||
"The new message was not on the mainline. "
|
||||
"Chat message history tree is not correctly built."
|
||||
)
|
||||
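# NOTE: illustrative sketch, not part of the diff. It summarizes the three
# parent_message_id cases handled above, assuming chat_history is ordered
# oldest-to-newest and AUTO_PLACE_AFTER_LATEST_MESSAGE == -1 as in the new code.
# Helper and variable names below are hypothetical.
def _resolve_parent_message(parent_message_id, chat_history, root_message):
    if parent_message_id == -1:  # AUTO_PLACE_AFTER_LATEST_MESSAGE
        # auto-place after the latest message in the chain
        return (chat_history[-1] if chat_history else root_message), chat_history
    if parent_message_id is None:
        # regeneration from the root: history is truncated to empty
        return root_message, []
    for i in range(len(chat_history) - 1, -1, -1):
        if chat_history[i].id == parent_message_id:
            # keep only messages up to and including the parent
            return chat_history[i], chat_history[: i + 1]
    raise ValueError("The new message sent is not on the latest mainline of messages")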
|
||||
# If the parent message is a user message, it's a regeneration and we use the existing user message.
|
||||
if parent_message.message_type == MessageType.USER:
|
||||
user_message = parent_message
|
||||
else:
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=parent_message,
|
||||
message=message_text,
|
||||
token_count=token_counter(message_text),
|
||||
message_type=MessageType.USER,
|
||||
files=new_msg_req.file_descriptors,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
chat_history.append(user_message)
|
||||
# At this point we can save the user message as it's validated and final
|
||||
db_session.commit()
|
||||
|
||||
memories = get_memories(user, db_session)
|
||||
|
||||
@@ -425,7 +433,7 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
files=last_chat_message.files,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
@@ -447,20 +455,15 @@ def handle_stream_message_objects(
|
||||
None,
|
||||
)
|
||||
|
||||
# Determine if search should be disabled for this project context
|
||||
forced_tool_id = new_msg_req.forced_tool_id
|
||||
project_search_config = _get_project_search_availability(
|
||||
# This may also mutate the new_msg_req.forced_tool_ids
|
||||
# This logic is specifically for projects
|
||||
search_usage_forcing_setting = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
loaded_project_files=bool(extracted_project_files.project_file_texts),
|
||||
project_has_files=bool(
|
||||
extracted_project_files.project_uncapped_token_count
|
||||
),
|
||||
forced_tool_id=new_msg_req.forced_tool_id,
|
||||
has_project_file_texts=bool(extracted_project_files.project_file_texts),
|
||||
forced_tool_ids=new_msg_req.forced_tool_ids,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
if project_search_config.disable_forced_tool:
|
||||
forced_tool_id = None
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
@@ -472,7 +475,7 @@ def handle_stream_message_objects(
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
user_selected_filters=user_selected_filters,
|
||||
project_id=(
|
||||
chat_session.project_id
|
||||
if extracted_project_files.project_as_filter
|
||||
@@ -482,20 +485,17 @@ def handle_stream_message_objects(
|
||||
slack_context=slack_context,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session.id,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
if forced_tool_id and forced_tool_id not in [tool.id for tool in tools]:
|
||||
raise ValueError(f"Forced tool {forced_tool_id} not found in tools")
|
||||
|
||||
# TODO Once summarization is done, we don't need to load all the files from the beginning anymore.
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(chat_history, db_session)
|
||||
@@ -505,7 +505,7 @@ def handle_stream_message_objects(
|
||||
# Reserve a message id for the assistant response for frontend to track packets
|
||||
assistant_response = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
@@ -529,16 +529,15 @@ def handle_stream_message_objects(
|
||||
redis_client = get_redis_client()
|
||||
|
||||
reset_cancel_status(
|
||||
chat_session.id,
|
||||
chat_session_id,
|
||||
redis_client,
|
||||
)
|
||||
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
return check_stop_signal(chat_session_id, redis_client)
|
||||
|
||||
# Use external state container if provided, otherwise create internal one
|
||||
# External container allows non-streaming callers to access accumulated state
|
||||
state_container = external_state_container or ChatStateContainer()
|
||||
# Create state container for accumulating partial results
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
@@ -566,7 +565,7 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_session_id=str(chat_session_id),
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
@@ -583,15 +582,19 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=forced_tool_id,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0]
|
||||
if new_msg_req.forced_tool_ids
|
||||
else None
|
||||
),
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
chat_session_id=str(chat_session_id),
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session.id} stopped by user")
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
@@ -693,63 +696,23 @@ def handle_stream_message_objects(
|
||||
return
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages. Both of the below are used for Slack
|
||||
# NOTE: is not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
# Slack context for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> AnswerStream:
|
||||
forced_tool_id = (
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
)
|
||||
if (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
):
|
||||
all_tools = get_tools(db_session)
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
objects = stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
)
|
||||
forced_tool_id = search_tool_id
|
||||
|
||||
translated_new_msg_req = SendMessageRequest(
|
||||
message=new_msg_req.message,
|
||||
llm_override=new_msg_req.llm_override,
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
forced_tool_id=forced_tool_id,
|
||||
file_descriptors=new_msg_req.file_descriptors,
|
||||
internal_search_filters=(
|
||||
new_msg_req.retrieval_options.filters
|
||||
if new_msg_req.retrieval_options
|
||||
else None
|
||||
),
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
bypass_acl=bypass_acl,
|
||||
additional_context=additional_context,
|
||||
slack_context=slack_context,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
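# NOTE: illustrative sketch, not part of the diff. stream_chat_message yields one
# JSON object per line (via get_json_line), so a caller can decode the stream as
# newline-delimited JSON; the transport shown here is an assumption.
import json

def read_packets(line_iter):
    """Decode the newline-delimited JSON stream produced by stream_chat_message."""
    for line in line_iter:
        if line.strip():
            yield json.loads(line)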
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
@@ -804,83 +767,3 @@ def gather_stream(
|
||||
error_msg=error_msg,
|
||||
top_documents=top_documents,
|
||||
)
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_full(
|
||||
packets: AnswerStream,
|
||||
state_container: ChatStateContainer,
|
||||
) -> ChatFullResponse:
|
||||
"""
Aggregate streaming packets and state container into a complete ChatFullResponse.

This function consumes all packets from the stream and combines them with
the accumulated state from the ChatStateContainer to build a complete response
including answer, reasoning, citations, and tool calls.

Args:
    packets: The stream of packets from handle_stream_message_objects
    state_container: The state container that accumulates tool calls, reasoning, etc.

Returns:
    ChatFullResponse with all available data
"""
answer: str | None = None
|
||||
citations: list[CitationInfo] = []
|
||||
error_msg: str | None = None
|
||||
message_id: int | None = None
|
||||
top_documents: list[SearchDoc] = []
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, Packet):
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
if packet.obj.final_documents:
|
||||
top_documents = packet.obj.final_documents
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
if answer is None:
|
||||
answer = ""
|
||||
if packet.obj.content:
|
||||
answer += packet.obj.content
|
||||
elif isinstance(packet.obj, CitationInfo):
|
||||
citations.append(packet.obj)
|
||||
elif isinstance(packet, StreamingError):
|
||||
error_msg = packet.error
|
||||
elif isinstance(packet, MessageResponseIDInfo):
|
||||
message_id = packet.reserved_assistant_message_id
|
||||
elif isinstance(packet, CreateChatSessionID):
|
||||
chat_session_id = packet.chat_session_id
|
||||
|
||||
if message_id is None:
|
||||
raise ValueError("Message ID is required")
|
||||
|
||||
# Use state_container for complete answer (handles edge cases gracefully)
|
||||
final_answer = state_container.get_answer_tokens() or answer or ""
|
||||
|
||||
# Get reasoning from state container (None when model doesn't produce reasoning)
|
||||
reasoning = state_container.get_reasoning_tokens()
|
||||
|
||||
# Convert ToolCallInfo list to ToolCallResponse list
|
||||
tool_call_responses = [
|
||||
ToolCallResponse(
|
||||
tool_name=tc.tool_name,
|
||||
tool_arguments=tc.tool_call_arguments,
|
||||
tool_result=tc.tool_call_response,
|
||||
search_docs=tc.search_docs,
|
||||
generated_images=tc.generated_images,
|
||||
pre_reasoning=tc.reasoning_tokens,
|
||||
)
|
||||
for tc in state_container.get_tool_calls()
|
||||
]
|
||||
|
||||
return ChatFullResponse(
|
||||
answer=final_answer,
|
||||
answer_citationless=remove_answer_citations(final_answer),
|
||||
pre_answer_reasoning=reasoning,
|
||||
tool_calls=tool_call_responses,
|
||||
top_documents=top_documents,
|
||||
citation_info=citations,
|
||||
message_id=message_id,
|
||||
chat_session_id=chat_session_id,
|
||||
error_msg=error_msg,
|
||||
)
|
||||
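# NOTE: illustrative sketch, not part of the diff. It shows how the new
# external_state_container parameter and gather_stream_full fit together for a
# non-streaming caller, based on the docstring above; any plumbing beyond what
# appears in this diff is assumed.
def answer_without_streaming(new_msg_req, user, db_session) -> ChatFullResponse:
    state_container = ChatStateContainer()
    packets = handle_stream_message_objects(
        new_msg_req=new_msg_req,
        user=user,
        db_session=db_session,
        external_state_container=state_container,
    )
    # Drain the stream and fold the accumulated state into one response object.
    return gather_stream_full(packets, state_container)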
|
||||
@@ -10,7 +10,6 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
@@ -23,7 +22,6 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -39,7 +37,7 @@ def get_default_base_system_prompt(db_session: Session) -> str:
|
||||
default_persona = get_default_behavior_persona(db_session)
|
||||
return (
|
||||
default_persona.system_prompt
|
||||
if default_persona and default_persona.system_prompt is not None
|
||||
if default_persona and default_persona.system_prompt
|
||||
else DEFAULT_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
@@ -117,11 +115,8 @@ def calculate_reserved_tokens(
|
||||
def build_reminder_message(
|
||||
reminder_text: str | None,
|
||||
include_citation_reminder: bool,
|
||||
is_last_cycle: bool,
|
||||
) -> str | None:
|
||||
reminder = reminder_text.strip() if reminder_text else ""
|
||||
if is_last_cycle:
|
||||
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
|
||||
if include_citation_reminder:
|
||||
reminder += "\n\n" + CITATION_REMINDER
|
||||
reminder = reminder.strip()
|
||||
@@ -174,9 +169,7 @@ def build_system_prompt(
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ WEB_SEARCH_GUIDANCE
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
@@ -202,16 +195,7 @@ def build_system_prompt(
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
site_disabled_guidance = ""
|
||||
if has_web_search:
|
||||
web_search_tool = next(
|
||||
(t for t in tools if isinstance(t, WebSearchTool)), None
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
)
|
||||
system_prompt += WEB_SEARCH_GUIDANCE
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
|
||||
@@ -117,30 +117,22 @@ def _create_and_link_tool_calls(
|
||||
tool_call_map[tool_call_obj.tool_call_id] = tool_call_obj.id
|
||||
|
||||
# Update parent_tool_call_id for all tool calls
|
||||
# Filter out orphaned children (whose parents don't exist) - this can happen
|
||||
# when generation is stopped mid-execution and parent tool calls were cancelled
|
||||
valid_tool_calls: list[ToolCall] = []
|
||||
for tool_call_obj in tool_call_objects:
|
||||
tool_call_info = tool_call_info_map[tool_call_obj.tool_call_id]
|
||||
if tool_call_info.parent_tool_call_id is not None:
|
||||
parent_id = tool_call_map.get(tool_call_info.parent_tool_call_id)
|
||||
if parent_id is not None:
|
||||
tool_call_obj.parent_tool_call_id = parent_id
|
||||
valid_tool_calls.append(tool_call_obj)
|
||||
else:
|
||||
# Parent doesn't exist (likely cancelled) - skip this orphaned child
|
||||
logger.warning(
|
||||
f"Skipping tool call '{tool_call_obj.tool_call_id}' with missing parent "
|
||||
f"'{tool_call_info.parent_tool_call_id}' (likely cancelled during execution)"
|
||||
# This would cause chat sessions to fail if this function is mistakenly called with
# tool calls that have bad parent pointers, but that falls under "fail loudly"
raise ValueError(
|
||||
f"Parent tool call with tool_call_id '{tool_call_info.parent_tool_call_id}' "
|
||||
f"not found for tool call '{tool_call_obj.tool_call_id}'"
|
||||
)
|
||||
# Remove from DB session to prevent saving
|
||||
db_session.delete(tool_call_obj)
|
||||
else:
|
||||
# Top-level tool call (no parent)
|
||||
valid_tool_calls.append(tool_call_obj)
|
||||
|
||||
# Link SearchDocs only to valid ToolCalls
|
||||
for tool_call_obj in valid_tool_calls:
|
||||
# Link SearchDocs to ToolCalls
|
||||
for tool_call_obj in tool_call_objects:
|
||||
search_doc_ids = tool_call_to_search_doc_ids.get(tool_call_obj.tool_call_id, [])
|
||||
if search_doc_ids:
|
||||
add_search_docs_to_tool_call(
|
||||
|
||||
@@ -2,23 +2,12 @@ from uuid import UUID

from redis.client import Redis

from shared_configs.contextvars import get_current_tenant_id

# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 10 * 60  # 10 minutes - defensive TTL to prevent memory leaks


def _get_fence_key(chat_session_id: UUID) -> str:
    """
    Generate the Redis key for a chat session stop signal fence.

    Args:
        chat_session_id: The UUID of the chat session

    Returns:
        The fence key string (tenant_id is automatically added by the Redis client)
    """
    return f"{FENCE_PREFIX}_{chat_session_id}"
FENCE_TTL = 24 * 60 * 60  # 24 hours - defensive TTL to prevent memory leaks


def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
@@ -27,10 +16,11 @@ def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
        redis_client: Redis client to use
        value: True to set the fence (stop signal), False to clear it
    """
    fence_key = _get_fence_key(chat_session_id)
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    if not value:
        redis_client.delete(fence_key)
        return
@@ -44,12 +34,13 @@ def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:

    Args:
        chat_session_id: The UUID of the chat session to check
        redis_client: Redis client to use for checking the stop signal (tenant-aware client that auto-prefixes keys)
        redis_client: Redis client to use for checking the stop signal

    Returns:
        True if the session should continue, False if it should stop
    """
    fence_key = _get_fence_key(chat_session_id)
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    return not bool(redis_client.exists(fence_key))


@@ -59,7 +50,8 @@ def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use (tenant-aware client that auto-prefixes keys)
        redis_client: Redis client to use
    """
    fence_key = _get_fence_key(chat_session_id)
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    redis_client.delete(fence_key)
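# NOTE: illustrative sketch, not part of the diff. It shows how the stop-signal
# fence is typically used around a generation loop, based on the functions above;
# the work loop itself is an assumption.
def run_with_stop_signal(chat_session_id: UUID, redis_client: Redis, steps) -> None:
    """Run `steps` (an iterable of zero-arg callables) until a stop fence appears."""
    reset_cancel_status(chat_session_id, redis_client)  # clear any stale fence
    for step in steps:
        if not is_connected(chat_session_id, redis_client):
            break  # a stop signal (fence key) was set for this session
        step()
# A "stop" request elsewhere simply calls:
#   set_fence(chat_session_id, redis_client, value=True)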
|
||||
@@ -120,14 +120,6 @@ VALID_EMAIL_DOMAINS = (
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
|
||||
# Disposable email blocking - blocks temporary/throwaway email addresses
|
||||
# Set to empty string to disable disposable email blocking
|
||||
DISPOSABLE_EMAIL_DOMAINS_URL = os.environ.get(
|
||||
"DISPOSABLE_EMAIL_DOMAINS_URL",
|
||||
"https://disposable.github.io/disposable-email-domains/domains.json",
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
@@ -194,16 +186,6 @@ TRACK_EXTERNAL_IDP_EXPIRY = (
|
||||
# DB Configs
|
||||
#####
|
||||
DOCUMENT_INDEX_NAME = "danswer_index"
|
||||
|
||||
OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
|
||||
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
|
||||
ENABLE_OPENSEARCH_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
# different host than the main vespa application
|
||||
@@ -679,6 +661,10 @@ INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 8
|
||||
)
|
||||
|
||||
# Maximum number of user file connector credential pairs to index in a single batch
|
||||
# Setting this number too high may overload the indexing process
|
||||
USER_FILE_INDEXING_LIMIT = int(os.environ.get("USER_FILE_INDEXING_LIMIT") or 100)
|
||||
|
||||
# Maximum file size in a document to be indexed
|
||||
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
|
||||
MAX_FILE_SIZE_BYTES = int(
|
||||
@@ -750,27 +736,7 @@ BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
|
||||
# Braintrust API key - if provided, Braintrust tracing will be enabled
|
||||
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
|
||||
# Maximum concurrency for Braintrust evaluations
|
||||
# None means unlimited concurrency, otherwise specify a number
|
||||
_braintrust_concurrency = os.environ.get("BRAINTRUST_MAX_CONCURRENCY")
|
||||
BRAINTRUST_MAX_CONCURRENCY = (
|
||||
int(_braintrust_concurrency) if _braintrust_concurrency else None
|
||||
)
|
||||
|
||||
#####
|
||||
# Scheduled Evals Configuration
|
||||
#####
|
||||
# Comma-separated list of Braintrust dataset names to run on schedule
|
||||
SCHEDULED_EVAL_DATASET_NAMES = [
|
||||
name.strip()
|
||||
for name in os.environ.get("SCHEDULED_EVAL_DATASET_NAMES", "").split(",")
|
||||
if name.strip()
|
||||
]
|
||||
# Email address to use for search permissions during scheduled evals
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL = os.environ.get(
|
||||
"SCHEDULED_EVAL_PERMISSIONS_EMAIL", "roshan@onyx.app"
|
||||
)
|
||||
# Braintrust project name to use for scheduled evals
|
||||
SCHEDULED_EVAL_PROJECT = os.environ.get("SCHEDULED_EVAL_PROJECT", "st-dev")
|
||||
BRAINTRUST_MAX_CONCURRENCY = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)
|
||||
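# NOTE: illustrative sketch, not part of the diff. The two parsing idioms shown
# in this hunk differ only in their fallback: one keeps None (meaning "unlimited"
# concurrency), the other substitutes a default of 5 when the variable is unset.
import os

_raw = os.environ.get("BRAINTRUST_MAX_CONCURRENCY")
optional_concurrency = int(_raw) if _raw else None  # None means unlimited
defaulted_concurrency = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)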
|
||||
#####
|
||||
# Langfuse Configuration
|
||||
@@ -806,16 +772,8 @@ try:
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Auto LLM Configuration - fetches model configs from GitHub for providers in Auto mode
|
||||
AUTO_LLM_CONFIG_URL = os.environ.get(
|
||||
"AUTO_LLM_CONFIG_URL",
|
||||
"https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/onyx/llm/well_known_providers/recommended-models.json",
|
||||
)
|
||||
|
||||
# How often to check for auto LLM model updates (in seconds)
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS = int(
|
||||
os.environ.get("AUTO_LLM_UPDATE_INTERVAL_SECONDS", 1800) # 30 minutes
|
||||
)
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
@@ -828,11 +786,6 @@ ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Image Generation Configuration (DEPRECATED)
|
||||
# These environment variables will be deprecated soon.
|
||||
# To configure image generation, please visit the Image Generation page in the Admin Panel.
|
||||
#####
|
||||
# Azure Image Configurations
|
||||
AZURE_IMAGE_API_VERSION = os.environ.get("AZURE_IMAGE_API_VERSION") or os.environ.get(
|
||||
"AZURE_DALLE_API_VERSION"
|
||||
@@ -917,19 +870,6 @@ DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# Captcha Configuration (for cloud signup protection)
|
||||
#####
|
||||
# Enable captcha verification for new user registration
|
||||
CAPTCHA_ENABLED = os.environ.get("CAPTCHA_ENABLED", "").lower() == "true"
|
||||
|
||||
# Google reCAPTCHA secret key (server-side validation)
|
||||
RECAPTCHA_SECRET_KEY = os.environ.get("RECAPTCHA_SECRET_KEY", "")
|
||||
|
||||
# Minimum score threshold for reCAPTCHA v3 (0.0-1.0, higher = more likely human)
|
||||
# 0.5 is the recommended default
|
||||
RECAPTCHA_SCORE_THRESHOLD = float(os.environ.get("RECAPTCHA_SCORE_THRESHOLD", "0.5"))
|
||||
|
||||
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
|
||||
|
||||
# Set to true to mock LLM responses for testing purposes
|
||||
@@ -983,15 +923,3 @@ S3_GENERATE_LOCAL_CHECKSUM = (
|
||||
# Forcing Vespa Language
|
||||
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
|
||||
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
|
||||
|
||||
|
||||
#####
|
||||
# Default LLM API Keys (for cloud deployments)
|
||||
# These are Onyx-managed API keys provided to tenants by default
|
||||
#####
|
||||
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
|
||||
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
|
||||
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
|
||||
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
|
||||
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
|
||||
|
||||
@@ -11,6 +11,9 @@ NUM_POSTPROCESSED_RESULTS = 20
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = int(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 25)
|
||||
# For Chat, need to keep enough space for history and other prompt pieces
|
||||
# ~3k input, half for docs, half for chat history + prompts
|
||||
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
|
||||
|
||||
# Maximum percentage of the context window to fill with selected sections
|
||||
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
|
||||
|
||||
@@ -146,6 +146,9 @@ CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
# Doc ID migration can be long-running; use a longer TTL and renew periodically
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT = 10 * 60 # 10 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
@@ -234,7 +237,6 @@ class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -363,6 +365,9 @@ class OnyxCeleryQueues:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# User file processing queue
|
||||
USER_FILE_PROCESSING = "user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
|
||||
@@ -421,6 +426,7 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
USER_FILE_DOCID_MIGRATION_LOCK = "da_lock:user_file_docid_migration"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -486,7 +492,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_AUTO_LLM_UPDATE = "check_for_auto_llm_update"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
# User file processing
|
||||
CHECK_FOR_USER_FILE_PROCESSING = "check_for_user_file_processing"
|
||||
@@ -527,6 +533,7 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
|
||||
|
||||
# chat retention
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
@@ -535,7 +542,6 @@ class OnyxCeleryTask:
|
||||
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
|
||||
|
||||
EVAL_RUN_TASK = "eval_run_task"
|
||||
SCHEDULED_EVAL_TASK = "scheduled_eval_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
@@ -556,9 +562,9 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore[attr-defined,unused-ignore]
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined,unused-ignore]
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
|
||||
@@ -128,17 +128,3 @@ if _LITELLM_EXTRA_BODY_RAW:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
#####
|
||||
# Prompt Caching Configs
|
||||
#####
|
||||
# Enable prompt caching framework
|
||||
ENABLE_PROMPT_CACHING = (
|
||||
os.environ.get("ENABLE_PROMPT_CACHING", "true").lower() != "false"
|
||||
)
|
||||
|
||||
# Cache TTL multiplier - store caches slightly longer than provider TTL
|
||||
# This allows for some clock skew and ensures we don't lose cache metadata prematurely
|
||||
PROMPT_CACHE_REDIS_TTL_MULTIPLIER = float(
|
||||
os.environ.get("PROMPT_CACHE_REDIS_TTL_MULTIPLIER") or 1.2
|
||||
)
|
||||
|
||||
@@ -961,20 +961,14 @@ def get_user_email_from_username__server(
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else "N/A"
|
||||
logger.warning(
|
||||
f"Failed to get confluence email for {user_name}: "
|
||||
f"HTTP {status_code} - {e}"
|
||||
)
|
||||
except Exception:
    logger.warning(f"failed to get confluence email for {user_name}")
    # For now, we'll just return None and log a warning. This means
    # we will keep retrying to get the email every group sync.
    email = None
except Exception as e:
    logger.warning(
        f"Failed to get confluence email for {user_name}: {type(e).__name__} - {e}"
    )
    email = None
    # We may want to just return a string that indicates failure so we don't
    # keep retrying
    # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -58,59 +58,51 @@ class DrupalWikiConnector(
|
||||
CheckpointedConnector[DrupalWikiCheckpoint],
|
||||
SlimConnector,
|
||||
):
|
||||
# Deprecated parameters that may exist in old connector configurations
|
||||
_DEPRECATED_PARAMS = {"drupal_wiki_scope", "include_all_spaces"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
spaces: list[str] | None = None,
|
||||
pages: list[str] | None = None,
|
||||
include_all_spaces: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
drupal_wiki_scope: str | None = None,
|
||||
include_attachments: bool = False,
|
||||
allow_images: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Drupal Wiki connector.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the Drupal Wiki instance (e.g., https://help.drupal-wiki.com)
|
||||
spaces: List of space IDs to index. If None and pages is also None, all spaces will be indexed.
|
||||
pages: List of page IDs to index. If provided, these specific pages will be indexed.
|
||||
spaces: List of space IDs to index. If None and include_all_spaces is False, no spaces will be indexed.
|
||||
pages: List of page IDs to index. If provided, only these specific pages will be indexed.
|
||||
include_all_spaces: If True, all spaces will be indexed regardless of the spaces parameter.
|
||||
batch_size: Number of documents to process in a batch.
|
||||
continue_on_failure: If True, continue indexing even if some documents fail.
|
||||
drupal_wiki_scope: The selected tab value from the frontend. If "all_spaces", all spaces will be indexed.
|
||||
include_attachments: If True, enable processing of page attachments including images and documents.
|
||||
allow_images: If True, enable processing of image attachments.
|
||||
"""
|
||||
|
||||
#########################################################
# TODO: Remove this after 02/01/2026 and remove **kwargs from the function signature
# Check for deprecated parameters from old connector configurations
# If attempting to update without deleting the connector:
# Remove the deprecated parameters from the custom_connector_config in the relevant connector table rows
deprecated_found = set(kwargs.keys()) & self._DEPRECATED_PARAMS
if deprecated_found:
    raise ConnectorValidationError(
        f"Outdated Drupal Wiki connector configuration detected "
        f"(found deprecated parameters: {', '.join(deprecated_found)}). "
        f"Please delete and recreate this connector, or contact Onyx support "
        f"for assistance with updating the configuration without deleting the connector."
    )
# Reject any other unexpected parameters
if kwargs:
    raise ConnectorValidationError(
        f"Unexpected parameters for Drupal Wiki connector: {', '.join(kwargs.keys())}"
    )
#########################################################
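# NOTE: illustrative sketch, not part of the diff. It shows what the validation
# above means for an old connector configuration; the config dict below is
# hypothetical.
old_config = {
    "base_url": "https://help.drupal-wiki.com",
    "drupal_wiki_scope": "all_spaces",  # deprecated key from an old configuration
}
try:
    DrupalWikiConnector(**old_config)
except ConnectorValidationError as e:
    # Raised because "drupal_wiki_scope" is in _DEPRECATED_PARAMS; the connector
    # must be recreated (or the stored custom_connector_config cleaned up).
    print(e)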
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.spaces = spaces or []
|
||||
self.pages = pages or []
|
||||
|
||||
# If no specific spaces or pages are provided, index all spaces
|
||||
self.include_all_spaces = not self.spaces and not self.pages
|
||||
# Determine whether to include all spaces based on the selected tab
|
||||
# If drupal_wiki_scope is "all_spaces", we should index all spaces
|
||||
# If it's "specific_spaces", we should only index the specified spaces
|
||||
# If it's None, we use the include_all_spaces parameter
|
||||
|
||||
if drupal_wiki_scope is not None:
|
||||
logger.debug(f"drupal_wiki_scope is set to {drupal_wiki_scope}")
|
||||
|
||||
self.include_all_spaces = drupal_wiki_scope == "all_spaces"
|
||||
# If scope is specific_spaces, include_all_spaces correctly defaults to False
|
||||
else:
|
||||
logger.debug(
|
||||
f"drupal_wiki_scope is not set, using include_all_spaces={include_all_spaces}"
|
||||
)
|
||||
self.include_all_spaces = include_all_spaces
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
|
||||
@@ -8,13 +8,10 @@ from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
@@ -41,6 +38,9 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
@@ -66,7 +66,7 @@ from onyx.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -157,7 +157,9 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -284,54 +286,6 @@ class GoogleDriveConnector(
|
||||
)
|
||||
return self._creds
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def normalize_url(cls, url: str) -> NormalizationResult:
|
||||
"""Normalize a Google Drive URL to match the canonical Document.id format.
|
||||
|
||||
Reuses the connector's existing document ID creation logic from
|
||||
onyx_document_id_from_drive_file.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
netloc = parsed.netloc.lower()
|
||||
|
||||
if not (
|
||||
netloc.startswith("docs.google.com")
|
||||
or netloc.startswith("drive.google.com")
|
||||
):
|
||||
return NormalizationResult(normalized_url=None, use_default=False)
|
||||
|
||||
# Handle ?id= query parameter case
|
||||
query_params = parse_qs(parsed.query)
|
||||
doc_id = query_params.get("id", [None])[0]
|
||||
if doc_id:
|
||||
scheme = parsed.scheme or "https"
|
||||
netloc = "drive.google.com"
|
||||
path = f"/file/d/{doc_id}"
|
||||
params = ""
|
||||
query = ""
|
||||
fragment = ""
|
||||
normalized = urlunparse(
|
||||
(scheme, netloc, path, params, query, fragment)
|
||||
).rstrip("/")
|
||||
return NormalizationResult(normalized_url=normalized, use_default=False)
|
||||
|
||||
# Extract file ID and use connector's function
|
||||
path_parts = parsed.path.split("/")
|
||||
file_id = None
|
||||
for i, part in enumerate(path_parts):
|
||||
if part == "d" and i + 1 < len(path_parts):
|
||||
file_id = path_parts[i + 1]
|
||||
break
|
||||
|
||||
if not file_id:
|
||||
return NormalizationResult(normalized_url=None, use_default=False)
|
||||
|
||||
# Create minimal file object for connector function
|
||||
file_obj = {"webViewLink": url, "id": file_id}
|
||||
normalized = onyx_document_id_from_drive_file(file_obj).rstrip("/")
|
||||
return NormalizationResult(normalized_url=normalized, use_default=False)
|
||||
|
||||
# TODO: ensure returned new_creds_dict is actually persisted when this is called?
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
try:
|
||||
@@ -1181,26 +1135,67 @@ class GoogleDriveConnector(
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
def _convert_retrieved_files_to_documents(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
drive_files_iter: Iterator[RetrievedDriveFile],
|
||||
include_permissions: bool,
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
"""
|
||||
Retrieves and converts Google Drive files to documents.
|
||||
Converts retrieved files to documents.
|
||||
"""
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
for retrieved_file in drive_files_iter:
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
# handle retrieval errors
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(retrieved_file, include_permissions),
|
||||
)
|
||||
for retrieved_file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
|
||||
results_cleaned = [result for result in results if result is not None]
|
||||
logger.debug(f"batch has {len(results_cleaned)} docs or failures")
|
||||
|
||||
yield from results_cleaned
|
||||
|
||||
def _convert_retrieved_file_to_document(
|
||||
self,
|
||||
retrieved_file: RetrievedDriveFile,
|
||||
include_permissions: bool,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Converts a retrieved file to a document.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
convert_drive_item_to_document,
|
||||
return convert_drive_item_to_document(
|
||||
self.creds,
|
||||
self.allow_images,
|
||||
self.size_threshold,
|
||||
@@ -1212,83 +1207,15 @@ class GoogleDriveConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
|
||||
def _yield_batch(
|
||||
files_batch: list[RetrievedDriveFile],
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
nonlocal batches_complete
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
convert_func,
|
||||
(
|
||||
[file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(
|
||||
file.drive_file, self.primary_admin_email
|
||||
),
|
||||
file.drive_file,
|
||||
),
|
||||
)
|
||||
for file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
logger.debug(
|
||||
f"finished processing batch {batches_complete} with {len(results)} results"
|
||||
)
|
||||
|
||||
docs_and_failures = [result for result in results if result is not None]
|
||||
logger.debug(
|
||||
f"batch {batches_complete} has {len(docs_and_failures)} docs or failures"
|
||||
)
|
||||
|
||||
if docs_and_failures:
|
||||
yield from docs_and_failures
|
||||
batches_complete += 1
|
||||
logger.debug(f"finished yielding batch {batches_complete}")
|
||||
|
||||
for retrieved_file in self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
|
||||
# handle retrieval errors
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
yield from _yield_batch(files_batch)
|
||||
checkpoint.retrieved_folder_and_drive_ids = (
|
||||
self._retrieved_folder_and_drive_ids
|
||||
[retrieved_file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(retrieved_file.drive_file, self.primary_admin_email),
|
||||
retrieved_file.drive_file,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
logger.exception(
|
||||
f"Error extracting document: {retrieved_file.drive_file.get('name')} from Google Drive"
|
||||
)
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
@@ -1313,8 +1240,19 @@ class GoogleDriveConnector(
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive(
|
||||
checkpoint, start, end, include_permissions
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
drive_files_iter = self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
yield from self._convert_retrieved_files_to_documents(
|
||||
drive_files_iter, include_permissions
|
||||
)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
@@ -1351,6 +1289,43 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self, errors: list[ConnectorFailure], include_permissions: bool = False
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
"""Attempts to yield back ALL the documents described by the errors, with no checkpointing.
It is the caller's responsibility to delete the old ConnectorFailures and replace them with the new ones.
"""
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = set(
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
)
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
files = get_files_by_web_view_links_batch(service, list(doc_ids), field_type)
|
||||
retrieved_iter = (
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in files.values()
|
||||
)
|
||||
yield from self._convert_retrieved_files_to_documents(
|
||||
retrieved_iter, include_permissions
|
||||
)
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
|
||||
@@ -3,9 +3,14 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
@@ -52,6 +57,8 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def generate_time_range_filter(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -393,3 +400,98 @@ def get_root_folder_id(service: Resource) -> str:
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
|
||||
|
||||
def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
    parsed = urlparse(web_view_link)
    path_parts = [part for part in parsed.path.split("/") if part]

    if "d" in path_parts:
        idx = path_parts.index("d")
        if idx + 1 < len(path_parts):
            return path_parts[idx + 1]

    query_params = parse_qs(parsed.query)
    for key in ("id", "fileId"):
        value = query_params.get(key)
        if value and value[0]:
            return value[0]

    raise ValueError(
        f"Unable to extract Drive file id from webViewLink: {web_view_link}"
    )
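# NOTE: illustrative sketch, not part of the diff. Example inputs the helper
# above is meant to handle; the URLs are hypothetical.
assert (
    _extract_file_id_from_web_view_link("https://docs.google.com/document/d/abc123/edit")
    == "abc123"
)
assert (
    _extract_file_id_from_web_view_link("https://drive.google.com/open?id=xyz789")
    == "xyz789"
)
# A link with neither a /d/<id> path segment nor an id/fileId query parameter
# raises ValueError.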
|
||||
|
||||
def get_file_by_web_view_link(
|
||||
service: GoogleDriveService,
|
||||
web_view_link: str,
|
||||
fields: str,
|
||||
) -> GoogleDriveFileType:
|
||||
"""
|
||||
Retrieve a Google Drive file using its webViewLink.
|
||||
"""
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
return (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
    service: GoogleDriveService,
    web_view_links: list[str],
    field_type: DriveFileFieldType,
) -> dict[str, GoogleDriveFileType]:
    fields = _get_fields_for_file_type(field_type)
    if len(web_view_links) <= MAX_BATCH_SIZE:
        return _get_files_by_web_view_links_batch(service, web_view_links, fields)

    ret = {}
    for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
        batch = web_view_links[i : i + MAX_BATCH_SIZE]
        ret.update(_get_files_by_web_view_links_batch(service, batch, fields))
    return ret
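# NOTE: illustrative sketch, not part of the diff. Usage of the batched lookup
# above, assuming `service` came from get_drive_service(...) as in resolve_errors;
# the links are hypothetical.
links = [f"https://drive.google.com/file/d/id{i}" for i in range(250)]
# Requests are issued in chunks of MAX_BATCH_SIZE (100), so 250 links become
# three batch HTTP requests; links that fail to resolve are omitted from the result.
found = get_files_by_web_view_links_batch(service, links, DriveFileFieldType.STANDARD)
missing = [link for link in links if link not in found]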
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> dict[str, GoogleDriveFileType]:
|
||||
"""
|
||||
Retrieve multiple Google Drive files using their webViewLinks in a single batch request.
|
||||
|
||||
Returns a dict mapping web_view_link to file metadata.
|
||||
Failed requests (due to invalid links or API errors) are omitted from the result.
|
||||
"""
|
||||
|
||||
def callback(
|
||||
request_id: str, response: GoogleDriveFileType, exception: Exception | None
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
else:
|
||||
results[request_id] = response
|
||||
|
||||
results: Dict[str, GoogleDriveFileType] = {}
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
|
||||
batch.execute()
|
||||
return results
|
||||
|
||||
@@ -25,18 +25,6 @@ GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CT = TypeVar("CT", bound=ConnectorCheckpoint)


class NormalizationResult(BaseModel):
    """Result of URL normalization attempt.

    Attributes:
        normalized_url: The normalized URL string, or None if normalization failed
        use_default: If True, fall back to default normalizer. If False, return None.
    """

    normalized_url: str | None
    use_default: bool = False


class BaseConnector(abc.ABC, Generic[CT]):
    REDIS_KEY_PREFIX = "da_connector_data:"

@@ -86,15 +74,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
        """Implement if the underlying connector wants to skip/allow image downloading
        based on the application level image analysis setting."""

    @classmethod
    def normalize_url(cls, url: str) -> "NormalizationResult":
        """Normalize a URL to match the canonical Document.id format used during ingestion.

        Connectors that use URLs as document IDs should override this method.
        Returns NormalizationResult with use_default=True if not implemented.
        """
        return NormalizationResult(normalized_url=None, use_default=True)

    def build_dummy_checkpoint(self) -> CT:
        # TODO: find a way to make this work without type: ignore
        return ConnectorCheckpoint(has_more=True)  # type: ignore
@@ -290,3 +269,15 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
        checkpoint: CT,
    ) -> CheckpointOutput[CT]:
        raise NotImplementedError


class Resolver:
    @abc.abstractmethod
    def resolve_errors(
        self, errors: list[ConnectorFailure], include_permissions: bool = False
    ) -> Generator[Document | ConnectorFailure, None, None]:
        """Attempts to yield back ALL the documents described by the errors, with no checkpointing.
        It is the caller's responsibility to delete the old ConnectorFailures and replace them with the new ones.
        If include_permissions is True, the documents will have permissions synced.
        """
        raise NotImplementedError

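# --- Hedged sketch (editor's illustration, not part of the change above) ---
# What a connector-specific normalize_url override might look like. "_ExampleConnector"
# and its URL handling are hypothetical; the real overrides (Linear, Notion, Slack)
# follow in the diffs below.
class _ExampleConnector(BaseConnector):
    @classmethod
    def normalize_url(cls, url: str) -> "NormalizationResult":
        if "example.com" not in url:
            # Not a URL this connector owns: do not fall back to the default normalizer.
            return NormalizationResult(normalized_url=None, use_default=False)
        # Strip a trailing slash so the URL matches the Document.id form used at ingestion.
        return NormalizationResult(normalized_url=url.rstrip("/"), use_default=False)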
@@ -1,13 +1,10 @@
import os
import re
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from urllib.parse import urlparse

import requests
from typing_extensions import override

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
@@ -19,7 +16,6 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -316,31 +312,6 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):

        yield from self._process_issues(start_str=start_time, end_str=end_time)

    @classmethod
    @override
    def normalize_url(cls, url: str) -> NormalizationResult:
        """Extract Linear issue identifier from URL.

        Linear URLs are like: https://linear.app/team/issue/IDENTIFIER/...
        Returns the identifier (e.g., "DAN-2327") which can be used to match Document.link.
        """
        parsed = urlparse(url)
        netloc = parsed.netloc.lower()

        if "linear.app" not in netloc:
            return NormalizationResult(normalized_url=None, use_default=False)

        # Extract identifier from path: /team/issue/IDENTIFIER/...
        # Pattern: /{team}/issue/{identifier}/...
        path_parts = [p for p in parsed.path.split("/") if p]
        if len(path_parts) >= 3 and path_parts[1] == "issue":
            identifier = path_parts[2]
            # Validate identifier format (e.g., "DAN-2327")
            if re.match(r"^[A-Z]+-\d+$", identifier):
                return NormalizationResult(normalized_url=identifier, use_default=False)

        return NormalizationResult(normalized_url=None, use_default=False)


if __name__ == "__main__":
    connector = LinearConnector()

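# --- Hedged illustration (editor's example, not part of the diff) ---
# Expected behaviour of the Linear normalize_url shown above; the URLs are made up.
assert (
    LinearConnector.normalize_url(
        "https://linear.app/acme/issue/DAN-2327/fix-the-thing"
    ).normalized_url
    == "DAN-2327"
)
assert (
    LinearConnector.normalize_url("https://example.com/DAN-2327").normalized_url is None
)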
@@ -1,17 +1,13 @@
import re
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlparse

import requests
from pydantic import BaseModel
from retry import retry
from typing_extensions import override

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
@@ -25,7 +21,6 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -106,49 +101,6 @@ class NotionConnector(LoadConnector, PollConnector):
        # very large, this may not be practical.
        self.recursive_index_enabled = recursive_index_enabled or self.root_page_id

    @classmethod
    @override
    def normalize_url(cls, url: str) -> NormalizationResult:
        """Normalize a Notion URL to extract the page ID (UUID format)."""
        parsed = urlparse(url)
        netloc = parsed.netloc.lower()

        if not ("notion.so" in netloc or "notion.site" in netloc):
            return NormalizationResult(normalized_url=None, use_default=False)

        # Extract page ID from path (format: "Title-PageID")
        path_last = parsed.path.split("/")[-1]
        candidate = path_last.split("-")[-1] if "-" in path_last else path_last

        # Clean and format as UUID
        candidate = re.sub(r"[^0-9a-fA-F-]", "", candidate)
        cleaned = candidate.replace("-", "")

        if len(cleaned) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", cleaned):
            normalized_uuid = (
                f"{cleaned[0:8]}-{cleaned[8:12]}-{cleaned[12:16]}-"
                f"{cleaned[16:20]}-{cleaned[20:]}"
            ).lower()
            return NormalizationResult(
                normalized_url=normalized_uuid, use_default=False
            )

        # Try query params
        params = parse_qs(parsed.query)
        for key in ("p", "page_id"):
            if key in params and params[key]:
                candidate = params[key][0].replace("-", "")
                if len(candidate) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", candidate):
                    normalized_uuid = (
                        f"{candidate[0:8]}-{candidate[8:12]}-{candidate[12:16]}-"
                        f"{candidate[16:20]}-{candidate[20:]}"
                    ).lower()
                    return NormalizationResult(
                        normalized_url=normalized_uuid, use_default=False
                    )

        return NormalizationResult(normalized_url=None, use_default=False)

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_child_blocks(
        self, block_id: str, cursor: str | None = None

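# --- Hedged illustration (editor's example, not part of the diff) ---
# Expected behaviour of the Notion normalize_url shown above; the page URL is made up.
_example_result = NotionConnector.normalize_url(
    "https://www.notion.so/My-Page-0123456789abcdef0123456789abcdef"
)
assert _example_result.normalized_url == "01234567-89ab-cdef-0123-456789abcdef"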
@@ -15,7 +15,6 @@ from http.client import RemoteDisconnected
from typing import Any
from typing import cast
from urllib.error import URLError
from urllib.parse import urlparse

from pydantic import BaseModel
from redis import Redis
@@ -42,7 +41,6 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
@@ -628,43 +626,6 @@ class SlackConnector(
        # self.delay_lock: str | None = None  # the redis key for the shared lock
        # self.delay_key: str | None = None  # the redis key for the shared delay

    @classmethod
    @override
    def normalize_url(cls, url: str) -> NormalizationResult:
        """Normalize a Slack URL to extract channel_id__thread_ts format."""
        parsed = urlparse(url)
        if "slack.com" not in parsed.netloc.lower():
            return NormalizationResult(normalized_url=None, use_default=False)

        # Slack document IDs are format: channel_id__thread_ts
        # Extract from URL pattern: .../archives/{channel_id}/p{timestamp}
        path_parts = parsed.path.split("/")
        if "archives" not in path_parts:
            return NormalizationResult(normalized_url=None, use_default=False)

        archives_idx = path_parts.index("archives")
        if archives_idx + 1 >= len(path_parts):
            return NormalizationResult(normalized_url=None, use_default=False)

        channel_id = path_parts[archives_idx + 1]
        if archives_idx + 2 >= len(path_parts):
            return NormalizationResult(normalized_url=None, use_default=False)

        thread_part = path_parts[archives_idx + 2]
        if not thread_part.startswith("p"):
            return NormalizationResult(normalized_url=None, use_default=False)

        # Convert p1234567890123456 to 1234567890.123456 format
        timestamp_str = thread_part[1:]  # Remove 'p' prefix
        if len(timestamp_str) == 16:
            # Insert dot at position 10 to match canonical format
            thread_ts = f"{timestamp_str[:10]}.{timestamp_str[10:]}"
        else:
            thread_ts = timestamp_str

        normalized = f"{channel_id}__{thread_ts}"
        return NormalizationResult(normalized_url=normalized, use_default=False)

    @staticmethod
    def make_credential_prefix(key: str) -> str:
        return f"connector:slack:credential_{key}"

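# --- Hedged illustration (editor's example, not part of the diff) ---
# Expected behaviour of the Slack normalize_url shown above; the workspace, channel
# and timestamp values are made up.
_example_result = SlackConnector.normalize_url(
    "https://acme.slack.com/archives/C0123456789/p1700000000123456"
)
assert _example_result.normalized_url == "C0123456789__1700000000.123456"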
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import ipaddress
|
||||
import random
|
||||
import socket
|
||||
@@ -35,11 +36,10 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from onyx.utils.web_content import extract_pdf_text
|
||||
from onyx.utils.web_content import is_pdf_resource
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -116,6 +116,16 @@ DEFAULT_HEADERS = {
    "Sec-CH-UA-Platform": '"macOS"',
}

# Common PDF MIME types
PDF_MIME_TYPES = [
    "application/pdf",
    "application/x-pdf",
    "application/acrobat",
    "application/vnd.pdf",
    "text/pdf",
    "text/x-pdf",
]


class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
    # Given a base site, index everything under that path
@@ -258,6 +268,12 @@ def get_internal_links(
    return internal_links


def is_pdf_content(response: requests.Response) -> bool:
    """Check if the response contains PDF content based on content-type header"""
    content_type = response.headers.get("content-type", "").lower()
    return any(pdf_type in content_type for pdf_type in PDF_MIME_TYPES)


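# --- Hedged usage sketch (editor's illustration, not part of the change above) ---
# How is_pdf_content might be combined with a HEAD request; the extension fallback
# mirrors the check used further down in this file. The URL handling is assumed.
def _example_probe_for_pdf(url: str) -> bool:
    head_response = requests.head(url, headers=DEFAULT_HEADERS, allow_redirects=True)
    return is_pdf_content(head_response) or url.lower().endswith(".pdf")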
def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
playwright = sync_playwright().start()
|
||||
|
||||
@@ -513,13 +529,14 @@ class WebConnector(LoadConnector):
|
||||
head_response = requests.head(
|
||||
initial_url, headers=DEFAULT_HEADERS, allow_redirects=True
|
||||
)
|
||||
content_type = head_response.headers.get("content-type")
|
||||
is_pdf = is_pdf_resource(initial_url, content_type)
|
||||
is_pdf = is_pdf_content(head_response)
|
||||
|
||||
if is_pdf:
|
||||
if is_pdf or initial_url.lower().endswith(".pdf"):
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(initial_url, headers=DEFAULT_HEADERS)
|
||||
page_text, metadata = extract_pdf_text(response.content)
|
||||
page_text, metadata, images = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
result.doc = Document(
|
||||
|
||||
@@ -1016,15 +1016,9 @@ def slack_retrieval(
|
||||
for query_string in query_strings
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
# This allows searching DMs/group DMs while still searching the specified channels.
|
||||
# Skip this if search_all_channels is already True (would be duplicate queries).
|
||||
if (
|
||||
entities
|
||||
and entities.get("include_dm")
|
||||
and not entities.get("search_all_channels")
|
||||
):
|
||||
# If include_dm is True, add additional searches without channel filters
|
||||
# This allows searching DMs/group DMs while still searching the specified channels
|
||||
if entities and entities.get("include_dm"):
|
||||
# Create a minimal entities dict that won't add channel filters
|
||||
# This ensures we search ALL conversations (DMs, group DMs, private channels)
|
||||
# BUT we still want to exclude channels specified in exclude_channels
|
||||
|
||||
@@ -398,8 +398,8 @@ def extract_channel_references_from_query(query_text: str) -> set[str]:
    channel_patterns = [
        r"\bin\s+(?:the\s+)?([a-z0-9_-]+)\s+(?:slack\s+)?channels?\b",  # "in the office channel"
        r"\bfrom\s+(?:the\s+)?([a-z0-9_-]+)\s+(?:slack\s+)?channels?\b",  # "from the office channel"
        r"\bin[:\s]*#([a-z0-9_-]+)\b",  # "in #office" or "in:#office"
        r"\bfrom[:\s]*#([a-z0-9_-]+)\b",  # "from #office" or "from:#office"
        r"\bin\s+#([a-z0-9_-]+)\b",  # "in #office"
        r"\bfrom\s+#([a-z0-9_-]+)\b",  # "from #office"
    ]

    for pattern in channel_patterns:

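# --- Hedged illustration (editor's example, not part of the diff) ---
# How the patterns above pull channel names out of a natural-language query; the
# query text is a made-up example.
_example_query = "what did we decide in the office channel last week?"
assert re.findall(
    r"\bin\s+(?:the\s+)?([a-z0-9_-]+)\s+(?:slack\s+)?channels?\b", _example_query
) == ["office"]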
@@ -253,7 +253,7 @@ class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
# If the Persona is configured to not run search (0 chunks), this is bypassed
|
||||
# If no Prompt is configured, the only search results are shown, this is bypassed
|
||||
run_search: OptionalSearchSetting = OptionalSearchSetting.AUTO
|
||||
run_search: OptionalSearchSetting = OptionalSearchSetting.ALWAYS
|
||||
# Is this a real-time/streaming call or a question where Onyx can take more time?
|
||||
# Used to determine reranking flow
|
||||
real_time: bool = True
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -248,6 +249,8 @@ def search_pipeline(
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# Needed for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -288,9 +291,11 @@ def search_pipeline(
|
||||
|
||||
retrieved_chunks = search_chunks(
|
||||
query_request=query_request,
|
||||
# Needed for federated Slack search
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
slack_context=slack_context,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -330,6 +331,7 @@ def retrieve_chunks(
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
@@ -346,7 +348,8 @@ def retrieve_chunks(
|
||||
user_id,
|
||||
list(query.filters.source_type) if query.filters.source_type else None,
|
||||
query.filters.document_set,
|
||||
user_file_ids=query.filters.user_file_ids,
|
||||
slack_context,
|
||||
query.filters.user_file_ids,
|
||||
)
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -457,6 +460,7 @@ def search_chunks(
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -472,6 +476,7 @@ def search_chunks(
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
slack_context=slack_context,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,10 +27,9 @@ Tool calls are stored in the ToolCall table and can represent all of the followi
|
||||
the agent call as a parent are the tool calls that happen as part of the agent.
|
||||
|
||||
The different branches are generated by sending a new search query to an existing parent.
```

[Empty Root Message] (This allows the first message to be branched/edited as well)
          /                    |                        \
[First Message]   [First Message Edit 1]   [First Message Edit 2]
        |                      |
[Second Message]   [Second Message of Edit 1 Branch]
```

@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -153,7 +152,6 @@ def get_chat_sessions_by_user(
|
||||
limit: int = 50,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
@@ -173,16 +171,6 @@ def get_chat_sessions_by_user(
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
stmt = stmt.where(non_system_message_exists_subq)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
|
||||
@@ -6,15 +6,21 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import lateral
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import USER_FILE_INDEXING_LIMIT
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -114,6 +120,7 @@ def get_connector_credential_pairs_for_user(
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
include_user_files: bool = False,
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
@@ -142,6 +149,9 @@ def get_connector_credential_pairs_for_user(
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_(False))
|
||||
|
||||
if order_by_desc:
|
||||
stmt = stmt.order_by(desc(ConnectorCredentialPair.id))
|
||||
|
||||
@@ -176,13 +186,16 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
db_session: Session, ids: list[int] | None = None
|
||||
db_session: Session, ids: list[int] | None = None, include_user_files: bool = False
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
@@ -229,12 +242,15 @@ def get_connector_credential_pair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
include_user_files: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -361,6 +377,8 @@ def _update_connector_credential_pair(
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
if status is not None:
|
||||
cc_pair.status = status
|
||||
if cc_pair.is_user_file:
|
||||
cc_pair.status = ConnectorCredentialPairStatus.PAUSED
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -426,10 +444,27 @@ def set_cc_pair_repeated_error_state(
|
||||
cc_pair_id: int,
|
||||
in_repeated_error_state: bool,
|
||||
) -> None:
|
||||
values: dict = {"in_repeated_error_state": in_repeated_error_state}
|
||||
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# However, don't pause if there's an active manual indexing trigger,
|
||||
# which indicates the user wants to retry immediately.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if in_repeated_error_state and AUTH_TYPE == AuthType.CLOUD:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
# Only pause if there's no manual indexing trigger active
|
||||
if cc_pair and cc_pair.indexing_trigger is None:
|
||||
values["status"] = ConnectorCredentialPairStatus.PAUSED
|
||||
|
||||
stmt = (
|
||||
update(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.values(in_repeated_error_state=in_repeated_error_state)
|
||||
.values(**values)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
@@ -501,6 +536,7 @@ def add_credential_to_connector(
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
is_user_file: bool = False,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -566,6 +602,7 @@ def add_credential_to_connector(
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
is_user_file=is_user_file,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
@@ -662,12 +699,67 @@ def fetch_indexable_standard_connector_credential_pair_ids(
|
||||
)
|
||||
)
|
||||
|
||||
# Exclude user files. NOTE: some cc pairs have null for is_user_file instead of False
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_not(True))
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return list(db_session.scalars(stmt))
|
||||
|
||||
|
||||
def fetch_indexable_user_file_connector_credential_pair_ids(
    db_session: Session,
    search_settings_id: int,
    limit: int | None = USER_FILE_INDEXING_LIMIT,
) -> list[int]:
    """
    Return up to `limit` user file connector_credential_pair IDs that still
    need indexing for the given `search_settings_id`.

    A cc_pair is considered "needs indexing" if its most recent IndexAttempt
    for this search_settings_id is either:
    - Missing entirely (no attempts yet)
    - Present but not SUCCESS status

    Implementation details:
    - Uses a LEFT JOIN LATERAL subquery to fetch only the single newest attempt
      per cc_pair (`ORDER BY time_updated DESC LIMIT 1`), instead of joining all
      attempts. This avoids scanning thousands of historical attempts and
      keeps memory/CPU usage low.
    - `ON TRUE` is required in the lateral join because the correlation to
      ConnectorCredentialPair.id happens inside the subquery itself.
    - NOTE: Shares some redundant logic with should_index() (TODO: combine)

    Returns:
        list[int]: connector_credential_pair IDs that should be indexed next
    """
    latest_attempt = lateral(
        select(IndexAttempt.status)
        .where(
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
            IndexAttempt.search_settings_id == search_settings_id,
        )
        .order_by(IndexAttempt.time_updated.desc())
        .limit(1)
    ).alias("latest_attempt")

    stmt = (
        select(ConnectorCredentialPair.id)
        .outerjoin(latest_attempt, true())  # ON TRUE, Postgres-style lateral join
        .where(
            ConnectorCredentialPair.is_user_file.is_(True),
            or_(
                latest_attempt.c.status.is_(None),  # no attempts at all
                latest_attempt.c.status != IndexingStatus.SUCCESS,  # latest != SUCCESS
            ),
        )
        .limit(limit)  # Always apply a limit when fetching user file cc pairs
    )

    return list(db_session.scalars(stmt))


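# --- Hedged usage sketch (editor's illustration, not part of the change above) ---
# Rough idea of how the helper above might be driven from an indexing loop; how the
# session and search settings are obtained is assumed, not shown in this diff.
def _example_enqueue_user_file_indexing(
    db_session: Session, search_settings_id: int
) -> list[int]:
    cc_pair_ids = fetch_indexable_user_file_connector_credential_pair_ids(
        db_session=db_session,
        search_settings_id=search_settings_id,
    )
    # Each returned cc_pair either has no IndexAttempt yet for these search settings,
    # or its newest attempt did not finish with SUCCESS.
    return cc_pair_ids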
def fetch_connector_credential_pair_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
|
||||
@@ -225,38 +225,6 @@ def get_documents_by_ids(
|
||||
return list(documents)
|
||||
|
||||
|
||||
def filter_existing_document_ids(
    db_session: Session,
    document_ids: list[str],
) -> set[str]:
    """Filter a list of document IDs to only those that exist in the database.

    Args:
        db_session: Database session
        document_ids: List of document IDs to check for existence

    Returns:
        Set of document IDs from the input list that exist in the database
    """
    if not document_ids:
        return set()
    stmt = select(DbDocument.id).where(DbDocument.id.in_(document_ids))
    return set(db_session.execute(stmt).scalars().all())


def fetch_document_ids_by_links(
    db_session: Session,
    links: list[str],
) -> dict[str, str]:
    """Fetch document IDs for documents whose link matches any of the provided values."""
    if not links:
        return {}

    stmt = select(DbDocument.link, DbDocument.id).where(DbDocument.link.in_(links))
    rows = db_session.execute(stmt).all()
    return {link: doc_id for link, doc_id in rows if link}


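# --- Hedged usage sketch (editor's illustration, not part of the change above) ---
# Resolving user-supplied URLs back to canonical document IDs via their stored links;
# the list of links and what the caller does with the IDs are assumptions here.
def _example_resolve_links_to_document_ids(
    db_session: Session, links: list[str]
) -> list[str]:
    link_to_doc_id = fetch_document_ids_by_links(db_session, links)
    # Links with no matching Document row are simply absent from the mapping.
    return [link_to_doc_id[link] for link in links if link in link_to_doc_id]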
def get_document_connector_count(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import ImageGenerationConfig
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default image generation config constants
|
||||
DEFAULT_IMAGE_PROVIDER_ID = "openai_gpt_image_1"
|
||||
DEFAULT_IMAGE_MODEL_NAME = "gpt-image-1"
|
||||
DEFAULT_IMAGE_PROVIDER = "openai"
|
||||
|
||||
|
||||
def create_image_generation_config__no_commit(
|
||||
db_session: Session,
|
||||
image_provider_id: str,
|
||||
model_configuration_id: int,
|
||||
is_default: bool = False,
|
||||
) -> ImageGenerationConfig:
|
||||
"""Create a new image generation config."""
|
||||
# If setting as default, clear ALL existing defaults in a single atomic update
|
||||
# This is more atomic than select-then-update pattern
|
||||
if is_default:
|
||||
db_session.execute(
|
||||
update(ImageGenerationConfig)
|
||||
.where(ImageGenerationConfig.is_default.is_(True))
|
||||
.values(is_default=False)
|
||||
)
|
||||
|
||||
new_config = ImageGenerationConfig(
|
||||
image_provider_id=image_provider_id,
|
||||
model_configuration_id=model_configuration_id,
|
||||
is_default=is_default,
|
||||
)
|
||||
db_session.add(new_config)
|
||||
db_session.flush()
|
||||
return new_config
|
||||
|
||||
|
||||
def get_all_image_generation_configs(
|
||||
db_session: Session,
|
||||
) -> list[ImageGenerationConfig]:
|
||||
"""Get all image generation configs.
|
||||
|
||||
Returns:
|
||||
List of all ImageGenerationConfig objects
|
||||
"""
|
||||
stmt = select(ImageGenerationConfig)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_image_generation_config(
|
||||
db_session: Session,
|
||||
image_provider_id: str,
|
||||
) -> ImageGenerationConfig | None:
|
||||
"""Get a single image generation config by image_provider_id with relationships loaded.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
image_provider_id: The image provider ID (primary key)
|
||||
|
||||
Returns:
|
||||
The ImageGenerationConfig or None if not found
|
||||
"""
|
||||
stmt = (
|
||||
select(ImageGenerationConfig)
|
||||
.where(ImageGenerationConfig.image_provider_id == image_provider_id)
|
||||
.options(
|
||||
selectinload(ImageGenerationConfig.model_configuration).selectinload(
|
||||
ModelConfiguration.llm_provider
|
||||
)
|
||||
)
|
||||
)
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_default_image_generation_config(
|
||||
db_session: Session,
|
||||
) -> ImageGenerationConfig | None:
|
||||
"""Get the default image generation config.
|
||||
|
||||
Returns:
|
||||
The default ImageGenerationConfig or None if not set
|
||||
"""
|
||||
stmt = (
|
||||
select(ImageGenerationConfig)
|
||||
.where(ImageGenerationConfig.is_default.is_(True))
|
||||
.options(
|
||||
selectinload(ImageGenerationConfig.model_configuration).selectinload(
|
||||
ModelConfiguration.llm_provider
|
||||
)
|
||||
)
|
||||
)
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def set_default_image_generation_config(
|
||||
db_session: Session,
|
||||
image_provider_id: str,
|
||||
) -> None:
|
||||
"""Set a config as the default (clears previous default).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
image_provider_id: The image provider ID to set as default
|
||||
|
||||
Raises:
|
||||
ValueError: If config not found
|
||||
"""
|
||||
# Get the config to set as default
|
||||
new_default = db_session.get(ImageGenerationConfig, image_provider_id)
|
||||
if not new_default:
|
||||
raise ValueError(
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found"
|
||||
)
|
||||
|
||||
# Clear ALL existing defaults in a single atomic update
|
||||
# This is more atomic than select-then-update pattern
|
||||
db_session.execute(
|
||||
update(ImageGenerationConfig)
|
||||
.where(
|
||||
ImageGenerationConfig.is_default.is_(True),
|
||||
ImageGenerationConfig.image_provider_id != image_provider_id,
|
||||
)
|
||||
.values(is_default=False)
|
||||
)
|
||||
|
||||
# Set new default
|
||||
new_default.is_default = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def unset_default_image_generation_config(
|
||||
db_session: Session,
|
||||
image_provider_id: str,
|
||||
) -> None:
|
||||
"""Unset a config as the default."""
|
||||
config = db_session.get(ImageGenerationConfig, image_provider_id)
|
||||
if not config:
|
||||
raise ValueError(
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found"
|
||||
)
|
||||
config.is_default = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_image_generation_config__no_commit(
|
||||
db_session: Session,
|
||||
image_provider_id: str,
|
||||
) -> None:
|
||||
"""Delete an image generation config by image_provider_id."""
|
||||
config = db_session.get(ImageGenerationConfig, image_provider_id)
|
||||
if not config:
|
||||
raise ValueError(
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found"
|
||||
)
|
||||
|
||||
db_session.delete(config)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def create_default_image_gen_config_from_api_key(
|
||||
db_session: Session,
|
||||
api_key: str,
|
||||
provider: str = DEFAULT_IMAGE_PROVIDER,
|
||||
image_provider_id: str = DEFAULT_IMAGE_PROVIDER_ID,
|
||||
model_name: str = DEFAULT_IMAGE_MODEL_NAME,
|
||||
) -> ImageGenerationConfig | None:
|
||||
"""Create default image gen config using an API key directly.
|
||||
|
||||
This function is used during tenant provisioning to automatically create
|
||||
a default image generation config when an OpenAI provider is configured.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
api_key: API key for the LLM provider
|
||||
provider: Provider name (default: openai)
|
||||
image_provider_id: Static unique key for the config (default: openai_gpt_image_1)
|
||||
model_name: Model name for image generation (default: gpt-image-1)
|
||||
|
||||
Returns:
|
||||
The created ImageGenerationConfig, or None if:
|
||||
- image_generation_config table already has records
|
||||
"""
|
||||
# Check if any image generation configs already exist (optimization to avoid work)
|
||||
existing_configs = get_all_image_generation_configs(db_session)
|
||||
if existing_configs:
|
||||
logger.info("Image generation config already exists, skipping default creation")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Create new LLM provider for image generation
|
||||
new_provider = LLMProvider(
|
||||
name=f"Image Gen - {image_provider_id}",
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
default_model_name=model_name,
|
||||
deployment_name=None,
|
||||
is_public=True,
|
||||
)
|
||||
db_session.add(new_provider)
|
||||
db_session.flush()
|
||||
|
||||
# Create model configuration
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=model_name,
|
||||
model_provider=provider,
|
||||
)
|
||||
|
||||
model_config = ModelConfiguration(
|
||||
llm_provider_id=new_provider.id,
|
||||
name=model_name,
|
||||
is_visible=True,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
db_session.add(model_config)
|
||||
db_session.flush()
|
||||
|
||||
# Create image generation config
|
||||
config = create_image_generation_config__no_commit(
|
||||
db_session=db_session,
|
||||
image_provider_id=image_provider_id,
|
||||
model_configuration_id=model_config.id,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
logger.info(f"Created default image generation config: {image_provider_id}")
|
||||
|
||||
return config
|
||||
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to create default image generation config {image_provider_id}"
|
||||
)
|
||||
return None
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import ImageGenerationConfig
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
@@ -17,7 +16,6 @@ from onyx.db.models import Tool as ToolModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
@@ -237,7 +235,6 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
|
||||
if not existing_llm_provider.id:
|
||||
@@ -374,29 +371,12 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = False,
|
||||
) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers with optional filtering.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
only_public: If True, only return public providers
|
||||
exclude_image_generation_providers: If True, exclude providers that are
|
||||
used for image generation configs
|
||||
"""
|
||||
stmt = select(LLMProviderModel).options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
|
||||
if exclude_image_generation_providers:
|
||||
# Get LLM provider IDs used by ImageGenerationConfig
|
||||
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
|
||||
ImageGenerationConfig
|
||||
)
|
||||
stmt = stmt.where(LLMProviderModel.id.not_in(image_gen_provider_ids))
|
||||
|
||||
providers = list(db_session.scalars(stmt).all())
|
||||
if only_public:
|
||||
return [provider for provider in providers if provider.is_public]
|
||||
@@ -502,30 +482,6 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> None:
|
||||
"""Remove LLM provider."""
|
||||
provider = db_session.get(LLMProviderModel, provider_id)
|
||||
if not provider:
|
||||
raise ValueError("LLM Provider not found")
|
||||
|
||||
# Clear the provider override from any personas using it
|
||||
# This causes them to fall back to the default provider
|
||||
personas_using_provider = get_personas_using_provider(db_session, provider.name)
|
||||
for persona in personas_using_provider:
|
||||
persona.llm_model_provider_override = None
|
||||
|
||||
db_session.execute(
|
||||
delete(LLMProvider__UserGroup).where(
|
||||
LLMProvider__UserGroup.llm_provider_id == provider_id
|
||||
)
|
||||
)
|
||||
# Remove LLMProvider
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
new_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
@@ -581,98 +537,3 @@ def update_default_vision_provider(
|
||||
new_default.is_default_vision_provider = True
|
||||
new_default.default_vision_model = vision_model
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers that are in Auto mode."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.is_auto_mode == True) # noqa: E712
|
||||
.options(selectinload(LLMProviderModel.model_configurations))
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def sync_auto_mode_models(
|
||||
db_session: Session,
|
||||
provider: LLMProviderModel,
|
||||
llm_recommendations: LLMRecommendations,
|
||||
) -> int:
|
||||
"""Sync models from GitHub config to a provider in Auto mode.
|
||||
|
||||
In Auto mode, the model list and default are controlled by GitHub config.
|
||||
The schema has:
|
||||
- default_model: The default model config (always visible)
|
||||
- additional_visible_models: List of additional visible models
|
||||
|
||||
Admin only provides API credentials.
|
||||
|
||||
    Args:
        db_session: Database session
        provider: LLM provider in Auto mode
        llm_recommendations: Recommended models and defaults loaded from the GitHub config

    Returns:
        The number of changes made.
"""
|
||||
changes = 0
|
||||
|
||||
# Build the list of all visible models from the config
|
||||
# All models in the config are visible (default + additional_visible_models)
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(provider.name)
|
||||
recommended_visible_model_names = [
|
||||
model.name for model in recommended_visible_models
|
||||
]
|
||||
|
||||
# Get existing models
|
||||
existing_models: dict[str, ModelConfiguration] = {
|
||||
mc.name: mc
|
||||
for mc in db_session.scalars(
|
||||
select(ModelConfiguration).where(
|
||||
ModelConfiguration.llm_provider_id == provider.id
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
# Remove models that are no longer in GitHub config
|
||||
for model_name, model in existing_models.items():
|
||||
if model_name not in recommended_visible_model_names:
|
||||
db_session.delete(model)
|
||||
changes += 1
|
||||
|
||||
# Add or update models from GitHub config
|
||||
for model_config in recommended_visible_models:
|
||||
if model_config.name in existing_models:
|
||||
# Update existing model
|
||||
existing = existing_models[model_config.name]
|
||||
# Check each field for changes
|
||||
updated = False
|
||||
if existing.display_name != model_config.display_name:
|
||||
existing.display_name = model_config.display_name
|
||||
updated = True
|
||||
# All models in the config are visible
|
||||
if not existing.is_visible:
|
||||
existing.is_visible = True
|
||||
updated = True
|
||||
if updated:
|
||||
changes += 1
|
||||
else:
|
||||
# Add new model - all models from GitHub config are visible
|
||||
new_model = ModelConfiguration(
|
||||
llm_provider_id=provider.id,
|
||||
name=model_config.name,
|
||||
display_name=model_config.display_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(new_model)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.name)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
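# --- Hedged usage sketch (editor's illustration, not part of the change above) ---
# Rough idea of how sync_auto_mode_models might be run across every Auto-mode provider;
# where the LLMRecommendations object comes from is assumed, not shown in this diff.
def _example_sync_all_auto_mode_providers(
    db_session: Session, llm_recommendations: LLMRecommendations
) -> int:
    total_changes = 0
    for provider in fetch_auto_mode_providers(db_session):
        total_changes += sync_auto_mode_models(db_session, provider, llm_recommendations)
    return total_changes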
@@ -369,8 +369,6 @@ class Notification(Base):
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
title: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="notifications")
|
||||
additional_data: Mapped[dict | None] = mapped_column(
|
||||
@@ -534,6 +532,7 @@ class ConnectorCredentialPair(Base):
|
||||
"""
|
||||
|
||||
__tablename__ = "connector_credential_pair"
|
||||
is_user_file: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True`
|
||||
# due to some SQLAlchemy quirks + this not being a primary key column
|
||||
id: Mapped[int] = mapped_column(
|
||||
@@ -2394,8 +2393,6 @@ class LLMProvider(Base):
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
# Auto mode: models, visibility, and defaults are managed by GitHub config
|
||||
is_auto_mode: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="llm_provider__user_group",
|
||||
@@ -2454,29 +2451,6 @@ class ModelConfiguration(Base):
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationConfig(Base):
|
||||
__tablename__ = "image_generation_config"
|
||||
|
||||
image_provider_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
model_configuration_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("model_configuration.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
model_configuration: Mapped["ModelConfiguration"] = relationship(
|
||||
"ModelConfiguration"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_image_generation_config_is_default", "is_default"),
|
||||
Index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
"model_configuration_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
@@ -3570,6 +3544,9 @@ class UserFile(Base):
|
||||
back_populates="user_files",
|
||||
)
|
||||
file_id: Mapped[str] = mapped_column(nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
nullable=False
|
||||
) # TODO(subash): legacy document_id, will be removed in a future migration
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
default=datetime.datetime.utcnow
|
||||
@@ -3597,6 +3574,9 @@ class UserFile(Base):
|
||||
|
||||
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
content_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
document_id_migrated: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
|
||||
projects: Mapped[list["UserProject"]] = relationship(
|
||||
"UserProject",
|
||||
@@ -3957,45 +3937,3 @@ class License(Base):
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class TenantUsage(Base):
|
||||
"""
|
||||
Tracks per-tenant usage statistics within a time window for cloud usage limits.
|
||||
|
||||
Each row represents usage for a specific tenant during a specific time window.
|
||||
A new row is created when the window rolls over (typically weekly).
|
||||
"""
|
||||
|
||||
__tablename__ = "tenant_usage"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# The start of the usage tracking window (e.g., start of the week in UTC)
|
||||
window_start: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Cumulative LLM usage cost in cents for the window
|
||||
llm_cost_cents: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
|
||||
# Number of chunks indexed during the window
|
||||
chunks_indexed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Number of API calls using API keys or Personal Access Tokens
|
||||
api_calls: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Number of non-streaming API calls (more expensive operations)
|
||||
non_streaming_api_calls: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0
|
||||
)
|
||||
|
||||
# Last updated timestamp for tracking freshness
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Ensure only one row per window start (tenant_id is in the schema name)
|
||||
UniqueConstraint("window_start", name="uq_tenant_usage_window"),
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ def create_notification(
|
||||
user_id: UUID | None,
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
) -> Notification:
|
||||
# Check if an undismissed notification of the same type and data exists
|
||||
@@ -40,8 +38,6 @@ def create_notification(
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
title=title,
|
||||
description=description,
|
||||
dismissed=False,
|
||||
last_shown=func.now(),
|
||||
first_shown=func.now(),
|
||||
|
||||
@@ -205,7 +205,6 @@ def make_persona_private(
|
||||
create_notification(
|
||||
user_id=user_uuid,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
title="A new agent was shared with you!",
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
@@ -270,7 +269,6 @@ def create_update_persona(
|
||||
system_prompt=create_persona_request.system_prompt,
|
||||
task_prompt=create_persona_request.task_prompt,
|
||||
datetime_aware=create_persona_request.datetime_aware,
|
||||
replace_base_system_prompt=create_persona_request.replace_base_system_prompt,
|
||||
uploaded_image_id=create_persona_request.uploaded_image_id,
|
||||
icon_name=create_persona_request.icon_name,
|
||||
display_priority=create_persona_request.display_priority,
|
||||
@@ -799,7 +797,6 @@ def upsert_persona(
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
@@ -908,7 +905,6 @@ def upsert_persona(
|
||||
existing_persona.task_prompt = task_prompt
|
||||
if datetime_aware is not None:
|
||||
existing_persona.datetime_aware = datetime_aware
|
||||
existing_persona.replace_base_system_prompt = replace_base_system_prompt
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -949,7 +945,6 @@ def upsert_persona(
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
datetime_aware=(datetime_aware if datetime_aware is not None else True),
|
||||
replace_base_system_prompt=replace_base_system_prompt,
|
||||
document_sets=document_sets or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
@@ -1188,15 +1183,13 @@ def update_default_assistant_configuration(
|
||||
db_session: Session,
|
||||
tool_ids: list[int] | None = None,
|
||||
system_prompt: str | None = None,
|
||||
update_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""Update only tools and system_prompt for the default assistant.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tool_ids: List of tool IDs to enable (if None, tools are not updated)
|
||||
system_prompt: New system prompt value (None means use default)
|
||||
update_system_prompt: If True, update the system_prompt field (allows setting to None)
|
||||
system_prompt: New system prompt (if None, system prompt is not updated)
|
||||
|
||||
Returns:
|
||||
Updated Persona object
|
||||
@@ -1209,8 +1202,8 @@ def update_default_assistant_configuration(
|
||||
if not persona:
|
||||
raise ValueError("Default assistant not found")
|
||||
|
||||
# Update system prompt if explicitly requested
|
||||
if update_system_prompt:
|
||||
# Update system prompt if provided
|
||||
if system_prompt is not None:
|
||||
persona.system_prompt = system_prompt
|
||||
|
||||
# Update tools if provided
|
||||
|
||||
@@ -21,7 +21,6 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files
|
||||
from onyx.server.features.projects.projects_file_utils import RejectedFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -30,7 +29,8 @@ logger = setup_logger()
|
||||
|
||||
class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
rejected_files: list[RejectedFile]
|
||||
non_accepted_files: list[str]
|
||||
unsupported_files: list[str]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -56,7 +56,8 @@ def create_user_files(
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
user_files = []
|
||||
rejected_files = categorized_files.rejected
|
||||
non_accepted_files = categorized_files.non_accepted
|
||||
unsupported_files = categorized_files.unsupported
|
||||
id_to_temp_id: dict[str, str] = {}
|
||||
# Pair returned storage paths with the same set of acceptable files we uploaded
|
||||
for file_path, file in zip(
|
||||
@@ -72,6 +73,7 @@ def create_user_files(
|
||||
id=new_id,
|
||||
user_id=user.id if user else None,
|
||||
file_id=file_path,
|
||||
document_id=str(new_id),
|
||||
name=file.filename,
|
||||
token_count=categorized_files.acceptable_file_to_token_count[
|
||||
file.filename or ""
|
||||
@@ -94,7 +96,8 @@ def create_user_files(
|
||||
db_session.commit()
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
non_accepted_files=non_accepted_files,
|
||||
unsupported_files=unsupported_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
)
|
||||
|
||||
@@ -119,14 +122,17 @@ def upload_files_to_user_files_with_indexing(
        temp_id_map=temp_id_map,
    )
    user_files = categorized_files_result.user_files
    rejected_files = categorized_files_result.rejected_files
    non_accepted_files = categorized_files_result.non_accepted_files
    unsupported_files = categorized_files_result.unsupported_files
    id_to_temp_id = categorized_files_result.id_to_temp_id
    # Trigger per-file processing immediately for the current tenant
    tenant_id = get_current_tenant_id()
    for rejected_file in rejected_files:
        logger.warning(
            f"File {rejected_file.filename} rejected for {rejected_file.reason}"
        )
    if non_accepted_files:
        for filename in non_accepted_files:
            logger.warning(f"Non-accepted file: {filename}")
    if unsupported_files:
        for filename in unsupported_files:
            logger.warning(f"Unsupported file: {filename}")
    for user_file in user_files:
        task = client_app.send_task(
            OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
@@ -140,7 +146,8 @@ def upload_files_to_user_files_with_indexing(

    return CategorizedFilesResult(
        user_files=user_files,
        rejected_files=rejected_files,
        non_accepted_files=non_accepted_files,
        unsupported_files=unsupported_files,
        id_to_temp_id=id_to_temp_id,
    )

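The hunks above pair the storage paths returned by upload_files back against the same ordered list of accepted files, and build an id_to_temp_id map so the client can reconcile its optimistic IDs with the real ones. A simplified, hypothetical sketch of that pairing idea (the names below are illustrative, not the actual helpers):

import uuid
from dataclasses import dataclass


@dataclass
class AcceptedFile:
    filename: str
    temp_id: str  # ID the client assigned before the upload


def pair_uploads(
    storage_paths: list[str], accepted_files: list[AcceptedFile]
) -> dict[str, str]:
    # zip() relies on the upload call preserving the order of the accepted files.
    id_to_temp_id: dict[str, str] = {}
    for file_path, file in zip(storage_paths, accepted_files):
        new_id = str(uuid.uuid4())
        id_to_temp_id[new_id] = file.temp_id
        # ...a UserFile row would be created here with file_id=file_path...
    return id_to_temp_id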
@@ -126,7 +126,7 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
    did change.
    """
    # The default CC-pair created for the Ingestion API is not counted here
    all_cc_pairs = get_connector_credential_pairs(db_session)
    all_cc_pairs = get_connector_credential_pairs(db_session, include_user_files=True)
    cc_pair_count = max(len(all_cc_pairs) - 1, 0)
    new_search_settings = get_secondary_search_settings(db_session)

@@ -52,7 +52,7 @@ def create_or_add_document_tag(
        is_list=False,
    )
    insert_stmt = insert_stmt.on_conflict_do_nothing(
        constraint="_tag_key_value_source_list_uc"
        index_elements=["tag_key", "tag_value", "source", "is_list"]
    )
    db_session.execute(insert_stmt)

@@ -98,7 +98,7 @@ def create_or_add_document_tag_list(
        is_list=True,
    )
    insert_stmt = insert_stmt.on_conflict_do_nothing(
        constraint="_tag_key_value_source_list_uc"
        index_elements=["tag_key", "tag_value", "source", "is_list"]
    )
    db_session.execute(insert_stmt)

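The two hunks above change the conflict target of the upsert from a named unique constraint to an explicit column list. A minimal, self-contained sketch of the same PostgreSQL INSERT ... ON CONFLICT DO NOTHING pattern in SQLAlchemy (the table and columns below are hypothetical, not the real Tag model):

from sqlalchemy import Column, Integer, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class ExampleTag(Base):
    __tablename__ = "example_tag"
    __table_args__ = (UniqueConstraint("tag_key", "tag_value"),)

    id = Column(Integer, primary_key=True)
    tag_key = Column(String, nullable=False)
    tag_value = Column(String, nullable=False)


def upsert_tag(db_session: Session, key: str, value: str) -> None:
    # If a row with the same (tag_key, tag_value) already exists, the insert
    # silently becomes a no-op instead of raising IntegrityError.
    stmt = pg_insert(ExampleTag).values(tag_key=key, tag_value=value)
    stmt = stmt.on_conflict_do_nothing(index_elements=["tag_key", "tag_value"])
    db_session.execute(stmt)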
@@ -1,209 +0,0 @@
"""Database interactions for tenant usage tracking (cloud usage limits)."""

from datetime import datetime
from datetime import timezone
from enum import Enum

from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from onyx.db.models import TenantUsage
from onyx.utils.logger import setup_logger
from shared_configs.configs import USAGE_LIMIT_WINDOW_SECONDS

logger = setup_logger()


class UsageType(str, Enum):
    """Types of usage that can be tracked and limited."""

    LLM_COST = "llm_cost_cents"
    CHUNKS_INDEXED = "chunks_indexed"
    API_CALLS = "api_calls"
    NON_STREAMING_API_CALLS = "non_streaming_api_calls"


class TenantUsageStats(BaseModel):
    """Current usage statistics for a tenant."""

    window_start: datetime
    llm_cost_cents: float
    chunks_indexed: int
    api_calls: int
    non_streaming_api_calls: int


class UsageLimitExceededError(Exception):
    """Raised when a tenant exceeds their usage limit."""

    def __init__(self, usage_type: UsageType, current: float, limit: float):
        self.usage_type = usage_type
        self.current = current
        self.limit = limit
        super().__init__(
            f"Usage limit exceeded for {usage_type.value}: "
            f"current usage {current}, limit {limit}"
        )

def get_current_window_start() -> datetime:
    """
    Calculate the start of the current usage window.

    Uses fixed windows aligned to Monday 00:00 UTC for predictability.
    The window duration is configured via USAGE_LIMIT_WINDOW_SECONDS.
    """
    now = datetime.now(timezone.utc)
    # For weekly windows (default), align to Monday 00:00 UTC
    if USAGE_LIMIT_WINDOW_SECONDS == 604800:  # 1 week
        # Get the start of the current week (Monday)
        days_since_monday = now.weekday()
        window_start = now.replace(
            hour=0, minute=0, second=0, microsecond=0
        ) - __import__("datetime").timedelta(days=days_since_monday)
        return window_start

    # For other window sizes, use epoch-aligned windows
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    seconds_since_epoch = int((now - epoch).total_seconds())
    window_number = seconds_since_epoch // USAGE_LIMIT_WINDOW_SECONDS
    window_start_seconds = window_number * USAGE_LIMIT_WINDOW_SECONDS
    return epoch + __import__("datetime").timedelta(seconds=window_start_seconds)

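The docstring above describes fixed, epoch-aligned windows. A standalone sketch of the same arithmetic, assuming a 1-day (86,400 s) window in place of the configured USAGE_LIMIT_WINDOW_SECONDS:

from datetime import datetime, timedelta, timezone

WINDOW_SECONDS = 86_400  # assumed window size, for illustration only


def example_window_start(now: datetime) -> datetime:
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    seconds_since_epoch = int((now - epoch).total_seconds())
    window_number = seconds_since_epoch // WINDOW_SECONDS
    return epoch + timedelta(seconds=window_number * WINDOW_SECONDS)


# Any time on 2024-05-05 UTC maps to the same window start (midnight that day).
assert example_window_start(
    datetime(2024, 5, 5, 17, 30, tzinfo=timezone.utc)
) == datetime(2024, 5, 5, tzinfo=timezone.utc)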
def get_or_create_tenant_usage(
    db_session: Session,
    window_start: datetime | None = None,
) -> TenantUsage:
    """
    Get or create the usage record for the current window.

    Uses INSERT ... ON CONFLICT DO UPDATE to atomically create or get the record,
    avoiding TOCTOU race conditions where two concurrent requests could both
    attempt to insert a new record.
    """
    if window_start is None:
        window_start = get_current_window_start()

    # Atomic upsert: insert if not exists, or update a field to itself if exists
    # This ensures we always get back a valid row without race conditions
    stmt = (
        pg_insert(TenantUsage)
        .values(
            window_start=window_start,
            llm_cost_cents=0.0,
            chunks_indexed=0,
            api_calls=0,
            non_streaming_api_calls=0,
        )
        .on_conflict_do_update(
            index_elements=["window_start"],
            # No-op update: just set a field to its current value
            # This ensures the row is returned even on conflict
            set_={"llm_cost_cents": TenantUsage.llm_cost_cents},
        )
        .returning(TenantUsage)
    )

    result = db_session.execute(stmt).scalar_one()
    db_session.flush()

    return result

def get_tenant_usage_stats(
    db_session: Session,
    window_start: datetime | None = None,
) -> TenantUsageStats:
    """Get the current usage statistics for the tenant (read-only, no lock)."""
    if window_start is None:
        window_start = get_current_window_start()

    usage = db_session.execute(
        select(TenantUsage).where(TenantUsage.window_start == window_start)
    ).scalar_one_or_none()

    if usage is None:
        # No usage recorded yet for this window
        return TenantUsageStats(
            window_start=window_start,
            llm_cost_cents=0.0,
            chunks_indexed=0,
            api_calls=0,
            non_streaming_api_calls=0,
        )

    return TenantUsageStats(
        window_start=usage.window_start,
        llm_cost_cents=usage.llm_cost_cents,
        chunks_indexed=usage.chunks_indexed,
        api_calls=usage.api_calls,
        non_streaming_api_calls=usage.non_streaming_api_calls,
    )

def increment_usage(
    db_session: Session,
    usage_type: UsageType,
    amount: float | int,
) -> None:
    """
    Atomically increment a usage counter.

    Uses row-level locking to prevent race conditions.
    The caller should handle the transaction commit.
    """
    usage = get_or_create_tenant_usage(db_session)

    if usage_type == UsageType.LLM_COST:
        usage.llm_cost_cents += float(amount)
    elif usage_type == UsageType.CHUNKS_INDEXED:
        usage.chunks_indexed += int(amount)
    elif usage_type == UsageType.API_CALLS:
        usage.api_calls += int(amount)
    elif usage_type == UsageType.NON_STREAMING_API_CALLS:
        usage.non_streaming_api_calls += int(amount)

    db_session.flush()

def check_usage_limit(
    db_session: Session,
    usage_type: UsageType,
    limit: float | int,
    pending_amount: float | int = 0,
) -> None:
    """
    Check if the current usage plus pending amount would exceed the limit.

    Args:
        db_session: Database session
        usage_type: Type of usage to check
        limit: The maximum allowed usage
        pending_amount: Amount about to be used (to check before committing)

    Raises:
        UsageLimitExceededError: If usage would exceed the limit
    """
    stats = get_tenant_usage_stats(db_session)

    current_value: float
    if usage_type == UsageType.LLM_COST:
        current_value = stats.llm_cost_cents
    elif usage_type == UsageType.CHUNKS_INDEXED:
        current_value = float(stats.chunks_indexed)
    elif usage_type == UsageType.API_CALLS:
        current_value = float(stats.api_calls)
    elif usage_type == UsageType.NON_STREAMING_API_CALLS:
        current_value = float(stats.non_streaming_api_calls)
    else:
        current_value = 0.0

    if current_value + pending_amount > limit:
        raise UsageLimitExceededError(
            usage_type=usage_type,
            current=current_value + pending_amount,
            limit=float(limit),
        )
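A hypothetical caller-side sketch of how check_usage_limit and increment_usage compose: check the limit for the pending work first, then record it. The wrapper name, the limit value, and the commit handling below are assumptions for illustration.

from sqlalchemy.orm import Session


def record_api_call(db_session: Session, api_call_limit: int = 10_000) -> None:
    # Raises UsageLimitExceededError if this call would push the tenant over.
    check_usage_limit(
        db_session,
        usage_type=UsageType.API_CALLS,
        limit=api_call_limit,
        pending_amount=1,
    )
    increment_usage(db_session, UsageType.API_CALLS, 1)
    db_session.commit()  # increment_usage leaves the commit to the caller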
@@ -33,16 +33,9 @@ from onyx.llm.models import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT
from onyx.prompts.deep_research.orchestration_layer import FINAL_REPORT_PROMPT
from onyx.prompts.deep_research.orchestration_layer import (
    INTERNAL_SEARCH_CLARIFICATION_GUIDANCE,
)
from onyx.prompts.deep_research.orchestration_layer import (
    INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE,
)
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_REMINDER
from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.server.query_and_chat.placement import Placement
@@ -205,21 +198,14 @@ def run_deep_research_llm_loop(
    # Filter tools to only allow web search, internal search, and open URL
    allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME}
    allowed_tools = [tool for tool in tools if tool.name in allowed_tool_names]
    include_internal_search_tunings = SearchTool.NAME in allowed_tool_names
    orchestrator_start_turn_index = 1

    #########################################################
    # CLARIFICATION STEP (optional)
    #########################################################
    internal_search_clarification_guidance = (
        INTERNAL_SEARCH_CLARIFICATION_GUIDANCE
        if include_internal_search_tunings
        else ""
    )
    if not skip_clarification:
        clarification_prompt = CLARIFICATION_PROMPT.format(
            current_datetime=get_current_llm_day_time(full_sentence=False),
            internal_search_clarification_guidance=internal_search_clarification_guidance,
            current_datetime=get_current_llm_day_time(full_sentence=False)
        )
        system_prompt = ChatMessageSimple(
            message=clarification_prompt,
@@ -276,19 +262,15 @@ def run_deep_research_llm_loop(
        token_count=300,
        message_type=MessageType.SYSTEM,
    )
    reminder_message = ChatMessageSimple(
        message=RESEARCH_PLAN_REMINDER,
        token_count=100,
        message_type=MessageType.USER,
    )

    truncated_message_history = construct_message_history(
        system_prompt=system_prompt,
        custom_agent_prompt=None,
        simple_chat_history=simple_chat_history + [reminder_message],
        simple_chat_history=simple_chat_history,
        reminder_message=None,
        project_files=None,
        available_tokens=available_tokens,
        last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
        last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
    )

    research_plan_generator = run_llm_step_pkt_generator(
@@ -363,17 +345,11 @@ def run_deep_research_llm_loop(
        else ORCHESTRATOR_PROMPT_REASONING
    )

    internal_search_research_task_guidance = (
        INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
        if include_internal_search_tunings
        else ""
    )
    token_count_prompt = orchestrator_prompt_template.format(
        current_datetime=get_current_llm_day_time(full_sentence=False),
        current_cycle_count=1,
        max_cycles=max_orchestrator_cycles,
        research_plan=research_plan,
        internal_search_research_task_guidance=internal_search_research_task_guidance,
    )
    orchestration_tokens = token_counter(token_count_prompt)

@@ -410,7 +386,6 @@ def run_deep_research_llm_loop(
            current_cycle_count=cycle,
            max_cycles=max_orchestrator_cycles,
            research_plan=research_plan,
            internal_search_research_task_guidance=internal_search_research_task_guidance,
        )

        system_prompt = ChatMessageSimple(
@@ -1,13 +1,9 @@
import httpx
from sqlalchemy.orm import Session

from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.opensearch_document_index import (
    OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT

@@ -27,24 +23,15 @@ def get_default_document_index(
    secondary_index_name = secondary_search_settings.index_name
    secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled

    if ENABLE_OPENSEARCH_FOR_ONYX:
        return OpenSearchOldDocumentIndex(
            index_name=search_settings.index_name,
            secondary_index_name=secondary_index_name,
            large_chunks_enabled=search_settings.large_chunks_enabled,
            secondary_large_chunks_enabled=secondary_large_chunks_enabled,
            multitenant=MULTI_TENANT,
            httpx_client=httpx_client,
        )
    else:
        return VespaIndex(
            index_name=search_settings.index_name,
            secondary_index_name=secondary_index_name,
            large_chunks_enabled=search_settings.large_chunks_enabled,
            secondary_large_chunks_enabled=secondary_large_chunks_enabled,
            multitenant=MULTI_TENANT,
            httpx_client=httpx_client,
        )
    # Currently only supporting Vespa
    return VespaIndex(
        index_name=search_settings.index_name,
        secondary_index_name=secondary_index_name,
        large_chunks_enabled=search_settings.large_chunks_enabled,
        secondary_large_chunks_enabled=secondary_large_chunks_enabled,
        multitenant=MULTI_TENANT,
        httpx_client=httpx_client,
    )


def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
@@ -109,6 +109,10 @@ class VespaDocumentFields:
    hidden: bool | None = None
    aggregated_chunk_boost_factor: float | None = None

    # document_id is added for migration purposes, ideally we should not be updating this field
    # TODO(subash): remove this field in a future migration
    document_id: str | None = None


@dataclass
class VespaDocumentUserFields:
@@ -1,8 +1,6 @@
import abc
from typing import Self

from pydantic import BaseModel
from pydantic import model_validator

from onyx.access.models import DocumentAccess
from onyx.configs.constants import PUBLIC_DOC_PAT
@@ -10,7 +8,6 @@ from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding

@@ -40,25 +37,6 @@ __all__ = [
]


class TenantState(BaseModel):
    """
    Captures the tenant-related state for an instance of DocumentIndex.

    NOTE: Tenant ID must be set in multitenant mode.
    """

    model_config = {"frozen": True}

    tenant_id: str
    multitenant: bool

    @model_validator(mode="after")
    def check_tenant_id_is_set_in_multitenant_mode(self) -> Self:
        if self.multitenant and not self.tenant_id:
            raise ValueError("Bug: Tenant ID must be set in multitenant mode.")
        return self


class DocumentInsertionRecord(BaseModel):
    """
    Result of indexing a document.
@@ -83,20 +61,6 @@ class DocumentSectionRequest(BaseModel):
    document_id: str
    min_chunk_ind: int | None = None
    max_chunk_ind: int | None = None
    # A given document can have multiple chunking strategies.
    max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE

    @model_validator(mode="after")
    def check_chunk_index_range_is_valid(self) -> Self:
        if (
            self.min_chunk_ind is not None
            and self.max_chunk_ind is not None
            and self.min_chunk_ind > self.max_chunk_ind
        ):
            raise ValueError(
                "Bug: Min chunk index must be less than or equal to max chunk index."
            )
        return self

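A small illustrative check of how the "after" validators above surface errors at construction time; the availability of DocumentSectionRequest at this point and the use of pytest are assumptions, not part of the diff:

import pytest
from pydantic import ValidationError


def test_chunk_index_range_validation() -> None:
    # A well-ordered range passes check_chunk_index_range_is_valid.
    DocumentSectionRequest(document_id="doc-1", min_chunk_ind=0, max_chunk_ind=5)

    # An inverted range is rejected; pydantic wraps the ValueError raised in
    # the validator into a ValidationError.
    with pytest.raises(ValidationError):
        DocumentSectionRequest(document_id="doc-1", min_chunk_ind=5, max_chunk_ind=0)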
class IndexingMetadata(BaseModel):
@@ -222,9 +186,9 @@ class Indexable(abc.ABC):
        cleaning / updating.

        Returns:
            List of document IDs which map to unique documents as well as if the
            document is newly indexed or had already existed and was just
            updated.
            List of document IDs which map to unique documents and are used for
            deduping chunks when updating, as well as if the document is newly
            indexed or already existed and just updated.
        """
        raise NotImplementedError

@@ -248,10 +212,6 @@ class Deletable(abc.ABC):
        Hard deletes all of the chunks for the corresponding document in the
        document index.

        TODO(andrei): Not a pressing issue now but think about what we want the
        contract of this method to be in the event the specified document ID
        does not exist.

        Args:
            document_id: The unique identifier for the document as represented
                in Onyx, not necessarily in the document index.
@@ -282,6 +242,10 @@ class Updatable(abc.ABC):
    def update(
        self,
        update_requests: list[MetadataUpdateRequest],
        # TODO(andrei), WARNING: Very temporary, this is not the interface we want
        # in Updatable, we only have this to continue supporting
        # user_file_docid_migration_task for Vespa which should be done soon.
        old_doc_id_to_new_doc_id: dict[str, str],
    ) -> None:
        """
        Updates some set of chunks. The document and fields to update are specified in the update
@@ -1,521 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.exceptions import TransportError
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
# Set the logging level to WARNING to ignore INFO and DEBUG logs from
|
||||
# opensearch. By default it emits INFO-level logs for every request.
|
||||
opensearch_logger = logging.getLogger("opensearchpy")
|
||||
opensearch_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
):
|
||||
self._index_name = index_name
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
http_auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
)
|
||||
|
||||
def create_index(self, mappings: dict[str, Any], settings: dict[str, Any]) -> None:
|
||||
"""Creates the index.
|
||||
|
||||
See the OpenSearch documentation for more information on mappings and
|
||||
settings.
|
||||
|
||||
Args:
|
||||
mappings: The mappings for the index to create.
|
||||
settings: The settings for the index to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the index.
|
||||
"""
|
||||
body: dict[str, Any] = {
|
||||
"mappings": mappings,
|
||||
"settings": settings,
|
||||
}
|
||||
response = self._client.indices.create(index=self._index_name, body=body)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create index {self._index_name}.")
|
||||
response_index = response.get("index", "")
|
||||
if response_index != self._index_name:
|
||||
raise RuntimeError(
|
||||
f"OpenSearch responded with index name {response_index} when creating index {self._index_name}."
|
||||
)
|
||||
|
||||
def delete_index(self) -> bool:
|
||||
"""Deletes the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the index.
|
||||
|
||||
Returns:
|
||||
True if the index was deleted, False if it did not exist.
|
||||
"""
|
||||
if not self._client.indices.exists(index=self._index_name):
|
||||
logger.warning(
|
||||
f"Tried to delete index {self._index_name} but it does not exist."
|
||||
)
|
||||
return False
|
||||
|
||||
response = self._client.indices.delete(index=self._index_name)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete index {self._index_name}.")
|
||||
return True
|
||||
|
||||
def index_exists(self) -> bool:
|
||||
"""Checks if the index exists.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error checking if the index exists.
|
||||
|
||||
Returns:
|
||||
True if the index exists, False if it does not.
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
|
||||
Short-circuit returns False on the first mismatch. Logs the mismatch.
|
||||
|
||||
See the OpenSearch documentation for more information on the index
|
||||
mappings.
|
||||
https://docs.opensearch.org/latest/mappings/
|
||||
|
||||
Args:
|
||||
mappings: The expected mappings of the index to validate.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error validating the index.
|
||||
|
||||
Returns:
|
||||
True if the index is valid, False if it is not based on the mappings
|
||||
supplied.
|
||||
"""
|
||||
# OpenSearch's documentation makes no mention of what happens when you
|
||||
# invoke client.indices.get on an index that does not exist, so we check
|
||||
# for existence explicitly just to be sure.
|
||||
exists_response = self.index_exists()
|
||||
if not exists_response:
|
||||
logger.warning(
|
||||
f"Tried to validate index {self._index_name} but it does not exist."
|
||||
)
|
||||
return False
|
||||
|
||||
get_result = self._client.indices.get(index=self._index_name)
|
||||
index_info: dict[str, Any] = get_result.get(self._index_name, {})
|
||||
if not index_info:
|
||||
raise ValueError(
|
||||
f"Bug: OpenSearch did not return any index info for index {self._index_name}, "
|
||||
"even though it confirmed that the index exists."
|
||||
)
|
||||
index_mapping_properties: dict[str, Any] = index_info.get("mappings", {}).get(
|
||||
"properties", {}
|
||||
)
|
||||
expected_mapping_properties: dict[str, Any] = expected_mappings.get(
|
||||
"properties", {}
|
||||
)
|
||||
assert (
|
||||
expected_mapping_properties
|
||||
), "Bug: No properties were found in the provided expected mappings."
|
||||
|
||||
for property in expected_mapping_properties:
|
||||
if property not in index_mapping_properties:
|
||||
logger.warning(
|
||||
f'The field "{property}" was not found in the index {self._index_name}.'
|
||||
)
|
||||
return False
|
||||
|
||||
expected_property_type = expected_mapping_properties[property].get(
|
||||
"type", ""
|
||||
)
|
||||
assert (
|
||||
expected_property_type
|
||||
), f'Bug: The field "{property}" in the supplied expected schema mappings has no type.'
|
||||
|
||||
index_property_type = index_mapping_properties[property].get("type", "")
|
||||
if expected_property_type != index_property_type:
|
||||
logger.warning(
|
||||
f'The field "{property}" in the index {self._index_name} has type {index_property_type} '
|
||||
f"but the expected type is {expected_property_type}."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def update_settings(self, settings: dict[str, Any]) -> None:
|
||||
"""Updates the settings of the index.
|
||||
|
||||
See the OpenSearch documentation for more information on the index
|
||||
settings.
|
||||
https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index-settings/
|
||||
|
||||
Args:
|
||||
settings: The settings to update the index with.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the settings of the index.
|
||||
"""
|
||||
# TODO(andrei): Implement this.
|
||||
raise NotImplementedError
|
||||
|
||||
def index_document(self, document: DocumentChunk) -> None:
|
||||
"""Indexes a document.
|
||||
|
||||
Indexing will fail if a document with the same ID already exists.
|
||||
|
||||
Args:
|
||||
document: The document to index. In Onyx this is a chunk of a
|
||||
document, OpenSearch simply refers to this as a document as
|
||||
well.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error indexing the document. This includes
|
||||
the case where a document with the same ID already exists.
|
||||
"""
|
||||
document_chunk_id: str = get_opensearch_doc_chunk_id(
|
||||
document_id=document.document_id,
|
||||
chunk_index=document.chunk_index,
|
||||
max_chunk_size=document.max_chunk_size,
|
||||
)
|
||||
body: dict[str, Any] = document.model_dump(exclude_none=True)
|
||||
# client.create will raise if a doc with the same ID exists.
|
||||
# client.index does not do this.
|
||||
result = self._client.create(
|
||||
index=self._index_name, id=document_chunk_id, body=body
|
||||
)
|
||||
result_id = result.get("_id", "")
|
||||
# Sanity check.
|
||||
if result_id != document_chunk_id:
|
||||
raise RuntimeError(
|
||||
f'Upon trying to index a document, OpenSearch responded with ID "{result_id}" '
|
||||
f'instead of "{document_chunk_id}" which is the ID it was given.'
|
||||
)
|
||||
result_string: str = result.get("result", "")
|
||||
match result_string:
|
||||
case "created":
|
||||
return
|
||||
# Sanity check.
|
||||
case "updated":
|
||||
raise RuntimeError(
|
||||
f'The OpenSearch client returned result "updated" for indexing document chunk "{document_chunk_id}". '
|
||||
"This indicates that a document chunk with that ID already exists, which is not expected."
|
||||
)
|
||||
case _:
|
||||
raise RuntimeError(
|
||||
f'Unknown OpenSearch indexing result: "{result_string}".'
|
||||
)
|
||||
|
||||
def delete_document(self, document_chunk_id: str) -> bool:
|
||||
"""Deletes a document.
|
||||
|
||||
Args:
|
||||
document_chunk_id: The OpenSearch ID of the document chunk to
|
||||
delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the document.
|
||||
|
||||
Returns:
|
||||
True if the document was deleted, False if it was not found.
|
||||
"""
|
||||
try:
|
||||
result = self._client.delete(index=self._index_name, id=document_chunk_id)
|
||||
except TransportError as e:
|
||||
if e.status_code == 404:
|
||||
return False
|
||||
else:
|
||||
raise e
|
||||
|
||||
result_string: str = result.get("result", "")
|
||||
match result_string:
|
||||
case "deleted":
|
||||
return True
|
||||
case "not_found":
|
||||
return False
|
||||
case _:
|
||||
raise RuntimeError(
|
||||
f'Unknown OpenSearch deletion result: "{result_string}".'
|
||||
)
|
||||
|
||||
def delete_by_query(self, query_body: dict[str, Any]) -> int:
|
||||
"""Deletes documents by a query.
|
||||
|
||||
Args:
|
||||
query_body: The body of the query to delete documents by.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the documents.
|
||||
|
||||
Returns:
|
||||
The number of documents deleted.
|
||||
"""
|
||||
result = self._client.delete_by_query(index=self._index_name, body=query_body)
|
||||
if result.get("timed_out", False):
|
||||
raise RuntimeError(
|
||||
f"Delete by query timed out for index {self._index_name}."
|
||||
)
|
||||
if len(result.get("failures", [])) > 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to delete some or all of the documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
num_deleted = result.get("deleted", 0)
|
||||
num_processed = result.get("total", 0)
|
||||
if num_deleted != num_processed:
|
||||
raise RuntimeError(
|
||||
f"Failed to delete some or all of the documents for index {self._index_name}. "
|
||||
f"{num_deleted} documents were deleted out of {num_processed} documents that were processed."
|
||||
)
|
||||
|
||||
return num_deleted
|
||||
|
||||
def update_document(self) -> None:
|
||||
# TODO(andrei): Implement this.
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
def get_document(self, document_chunk_id: str) -> DocumentChunk:
|
||||
"""Gets a document.
|
||||
|
||||
Will raise an exception if the document is not found.
|
||||
|
||||
Args:
|
||||
document_chunk_id: The OpenSearch ID of the document chunk to get.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error getting the document. This includes
|
||||
the case where the document is not found.
|
||||
|
||||
Returns:
|
||||
The document chunk.
|
||||
"""
|
||||
result = self._client.get(index=self._index_name, id=document_chunk_id)
|
||||
found_result: bool = result.get("found", False)
|
||||
if not found_result:
|
||||
raise RuntimeError(
|
||||
f'Document chunk with ID "{document_chunk_id}" was not found.'
|
||||
)
|
||||
|
||||
document_chunk_source: dict[str, Any] | None = result.get("_source")
|
||||
if not document_chunk_source:
|
||||
raise RuntimeError(
|
||||
f'Document chunk with ID "{document_chunk_id}" has no data.'
|
||||
)
|
||||
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
) -> list[DocumentChunk]:
|
||||
"""Searches the index.
|
||||
|
||||
TODO(andrei): Ideally we could check that every field in the body is
|
||||
present in the index, to avoid a class of runtime bugs that could easily
|
||||
be caught during development.
|
||||
|
||||
Args:
|
||||
body: The body of the search request. See the OpenSearch
|
||||
documentation for more information on search request bodies.
|
||||
search_pipeline_id: The ID of the search pipeline to use. If None,
|
||||
the default search pipeline will be used.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
|
||||
Returns:
|
||||
List of document chunks that match the search request.
|
||||
"""
|
||||
result: dict[str, Any]
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name, search_pipeline=search_pipeline_id, body=body
|
||||
)
|
||||
else:
|
||||
result = self._client.search(index=self._index_name, body=body)
|
||||
|
||||
hits = self._get_hits_from_search_result(result)
|
||||
|
||||
result_chunks: list[DocumentChunk] = []
|
||||
for hit in hits:
|
||||
document_chunk_source: dict[str, Any] | None = hit.get("_source")
|
||||
if not document_chunk_source:
|
||||
raise RuntimeError(
|
||||
f"Document chunk with ID \"{hit.get('_id', '')}\" has no data."
|
||||
)
|
||||
result_chunks.append(DocumentChunk.model_validate(document_chunk_source))
|
||||
return result_chunks
|
||||
|
||||
def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
In order to take advantage of the performance benefits of only returning
|
||||
IDs, the body should have a key, value pair of "_source": False.
|
||||
Otherwise, OpenSearch will return the entire document body and this
|
||||
method's performance will be the same as the search method's.
|
||||
|
||||
TODO(andrei): Ideally we could check that every field in the body is
|
||||
present in the index, to avoid a class of runtime bugs that could easily
|
||||
be caught during development.
|
||||
|
||||
Args:
|
||||
body: The body of the search request. See the OpenSearch
|
||||
documentation for more information on search request bodies.
|
||||
TODO(andrei): Make this a more deep interface; callers shouldn't
|
||||
need to know to set _source: False for example.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error searching the index.
|
||||
|
||||
Returns:
|
||||
List of document chunk IDs that match the search request.
|
||||
"""
|
||||
if "_source" not in body or body["_source"] is not False:
|
||||
logger.warning(
|
||||
"The body of the search request for document chunk IDs is missing the key, value pair of "
|
||||
'"_source": False. This query will therefore be inefficient.'
|
||||
)
|
||||
|
||||
result: dict[str, Any] = self._client.search(index=self._index_name, body=body)
|
||||
|
||||
hits = self._get_hits_from_search_result(result)
|
||||
|
||||
# TODO(andrei): Implement scroll/point in time for results so that we
|
||||
# can return arbitrarily-many IDs.
|
||||
if len(hits) == DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
logger.warning(
|
||||
"The search request for document chunk IDs returned the maximum number of results. "
|
||||
"It is extremely likely that there are more hits in OpenSearch than the returned results."
|
||||
)
|
||||
|
||||
# Extract only the _id field from each hit.
|
||||
document_chunk_ids: list[str] = []
|
||||
for hit in hits:
|
||||
document_chunk_id = hit.get("_id")
|
||||
if not document_chunk_id:
|
||||
raise RuntimeError(
|
||||
"Received a hit from OpenSearch but the _id field is missing."
|
||||
)
|
||||
document_chunk_ids.append(document_chunk_id)
|
||||
return document_chunk_ids
|
||||
|
||||
def refresh_index(self) -> None:
|
||||
"""Refreshes the index to make recent changes searchable.
|
||||
|
||||
In OpenSearch, documents are not immediately searchable after indexing.
|
||||
This method forces a refresh to make them available for search.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error refreshing the index.
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_from_search_result(self, result: dict[str, Any]) -> list[Any]:
|
||||
"""Extracts the hits from a search result.
|
||||
|
||||
Args:
|
||||
result: The search result to extract the hits from.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error extracting the hits from the search
|
||||
result. This includes the case where the search timed out.
|
||||
|
||||
Returns:
|
||||
The hits from the search result.
|
||||
"""
|
||||
if result.get("timed_out", False):
|
||||
raise RuntimeError(f"Search timed out for index {self._index_name}.")
|
||||
hits_first_layer: dict[str, Any] = result.get("hits", {})
|
||||
if not hits_first_layer:
|
||||
raise RuntimeError(
|
||||
f"Hits field missing from response when trying to search index {self._index_name}."
|
||||
)
|
||||
hits_second_layer: list[Any] = hits_first_layer.get("hits", [])
|
||||
return hits_second_layer
|
||||
@@ -1,23 +0,0 @@
# Size of the dynamic list used to consider elements during kNN graph creation.
# Higher values improve search quality but increase indexing time. Values
# typically range between 100 - 512.
EF_CONSTRUCTION = 256
# Number of bi-directional links per element. Higher values improve search
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32  # Increased for better accuracy.

# Default value for the maximum number of tokens a chunk can hold, if none is
# specified when creating an index.
DEFAULT_MAX_CHUNK_SIZE = 512

# Number of vectors to examine for top k neighbors for the HNSW method. Values
# typically range between 100 - 200.
EF_SEARCH = 200

# Default weights to use for hybrid search normalization. These values should
# sum to 1.
SEARCH_TITLE_VECTOR_WEIGHT = 0.05
SEARCH_TITLE_KEYWORD_WEIGHT = 0.05
SEARCH_CONTENT_VECTOR_WEIGHT = 0.50  # Increased to favor semantic search.
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.35  # Decreased to favor semantic search.
SEARCH_CONTENT_PHRASE_WEIGHT = 0.05
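For context on where constants like these usually end up, a rough sketch of an OpenSearch HNSW field definition; this is not the deleted schema module, and the field name, embedding dimension, and space type below are assumptions:

EF_CONSTRUCTION = 256
M = 32
EF_SEARCH = 200

example_settings = {
    "index": {
        "knn": True,
        # How many candidates are examined per query; higher is slower but more accurate.
        "knn.algo_param.ef_search": EF_SEARCH,
    }
}

example_mappings = {
    "properties": {
        "content_vector": {
            "type": "knn_vector",
            "dimension": 768,  # assumed embedding dimension
            "method": {
                "name": "hnsw",
                "space_type": "cosinesimil",
                # Graph-construction quality vs. indexing time / memory trade-off.
                "parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
            },
        }
    }
}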
@@ -1,608 +0,0 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.context.search.models import QueryExpansionType
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.document_index.interfaces import DocumentIndex as OldDocumentIndex
|
||||
from onyx.document_index.interfaces import (
|
||||
DocumentInsertionRecord as OldDocumentInsertionRecord,
|
||||
)
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import UpdateRequest
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.interfaces_new import DocumentIndex
|
||||
from onyx.document_index.interfaces_new import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces_new import DocumentSectionRequest
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
) -> InferenceChunkUncleaned:
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=chunk.chunk_index,
|
||||
blurb=chunk.blurb,
|
||||
content=chunk.content,
|
||||
source_links=json.loads(chunk.source_links) if chunk.source_links else None,
|
||||
image_file_id=chunk.image_file_name,
|
||||
# TODO(andrei) Yuhong says he doesn't think we need that anymore. Used
|
||||
# if a section needed to be split into diff chunks. A section is a part
|
||||
# of a doc that a link will take you to. But don't chunks have their own
|
||||
# links? Look at this in a followup.
|
||||
section_continuation=False,
|
||||
document_id=chunk.document_id,
|
||||
source_type=DocumentSource(chunk.source_type),
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
title=chunk.title,
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document. Yuhong thinks OpenSearch
|
||||
# has some thing out of the box for this. Just need to look at it in a
|
||||
# followup.
|
||||
boost=1,
|
||||
# TODO(andrei): Do in a followup.
|
||||
recency_bias=1.0,
|
||||
# TODO(andrei): This is how good the match is, we need this, key insight
|
||||
# is we can order chunks by this. Should not be hard to plumb this from
|
||||
# a search result, do that in a followup.
|
||||
score=None,
|
||||
hidden=chunk.hidden,
|
||||
# TODO(andrei): Don't worry about these for now.
|
||||
# is_relevant
|
||||
# relevance_explanation
|
||||
# metadata
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document.
|
||||
metadata={},
|
||||
# TODO(andrei): The vector DB needs to supply this. I vaguely know
|
||||
# OpenSearch can from the documentation I've seen till now, look at this
|
||||
# in a followup.
|
||||
match_highlights=[],
|
||||
# TODO(andrei) This content is not queried on, it is only used to clean
|
||||
# appended content to chunks. Consider storing a chunk content index
|
||||
# instead of a full string when working on chunk content augmentation.
|
||||
doc_summary="",
|
||||
# TODO(andrei) Same thing as contx ret above, LLM gens context for each
|
||||
# chunk.
|
||||
chunk_context="",
|
||||
updated_at=chunk.last_updated,
|
||||
# primary_owners TODO(andrei)
|
||||
# secondary_owners TODO(andrei)
|
||||
# large_chunk_reference_ids TODO(andrei): Don't worry about this one.
|
||||
# TODO(andrei): This is the suffix appended to the end of the chunk
|
||||
# content to assist querying. There are better ways we can do this, for
|
||||
# ex. keeping an index of where to string split from.
|
||||
metadata_suffix=None,
|
||||
)
|
||||
|
||||
|
||||
def _convert_inference_chunk_uncleaned_to_inference_chunk(
|
||||
inference_chunk_uncleaned: InferenceChunkUncleaned,
|
||||
) -> InferenceChunk:
|
||||
# TODO(andrei): Implement this.
|
||||
return inference_chunk_uncleaned.to_inference_chunk()
|
||||
|
||||
|
||||
def _convert_onyx_chunk_to_opensearch_document(
|
||||
chunk: DocMetadataAwareIndexChunk,
|
||||
) -> DocumentChunk:
|
||||
return DocumentChunk(
|
||||
document_id=chunk.source_document.id,
|
||||
chunk_index=chunk.chunk_id,
|
||||
title=chunk.source_document.title,
|
||||
title_vector=chunk.title_embedding,
|
||||
content=chunk.content,
|
||||
content_vector=chunk.embeddings.full_embedding,
|
||||
# TODO(andrei): We should know this. Reason to have this is convenience,
|
||||
# but it could also change when you change your embedding model, maybe
|
||||
# we can remove it, Yuhong to look at this. Hardcoded to some nonsense
|
||||
# value for now.
|
||||
num_tokens=0,
|
||||
source_type=chunk.source_document.source.value,
|
||||
# TODO(andrei): This is just represented a bit differently in
|
||||
# DocumentBase than how we expect it in the schema currently. Look at
|
||||
# this closer in a followup. Always defaults to None for now.
|
||||
# metadata=chunk.source_document.metadata,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
# TODO(andrei): Don't currently see an easy way of porting this, and
|
||||
# besides some connectors genuinely don't have this data. Look at this
|
||||
# closer in a followup. Always defaults to None for now.
|
||||
# created_at=None,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): Implement ACL in a followup, currently none of the
|
||||
# methods in OpenSearchDocumentIndex support it anyway. Always defaults
|
||||
# to None for now.
|
||||
# access_control_list=chunk.access.to_acl(),
|
||||
# TODO(andrei): This doesn't work bc global_boost is float, presumably
|
||||
# between 0.0 and inf (check this) and chunk.boost is an int from -inf
|
||||
# to +inf. Look at how the scaling compares between these in a followup.
|
||||
# Always defaults to 1.0 for now.
|
||||
# global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
# TODO(andrei): Ask Chris more about this later. Always defaults to None
|
||||
# for now.
|
||||
# image_file_name=None,
|
||||
source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
document_sets=list(chunk.document_sets) if chunk.document_sets else None,
|
||||
project_ids=list(chunk.user_project) if chunk.user_project else None,
|
||||
# TODO(andrei): Consider not even getting this from
|
||||
# DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's
|
||||
# instance variable. One source of truth -> less chance of a very bad
|
||||
# bug in prod.
|
||||
tenant_id=chunk.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def _enrich_chunk_info() -> None: # pyright: ignore[reportUnusedFunction]
|
||||
# TODO(andrei): Implement this. Until then, we do not enrich chunk content
|
||||
# with title, etc.
|
||||
raise NotImplementedError(
|
||||
"[ANDREI]: Enrich chunk info is not implemented for OpenSearch."
|
||||
)
|
||||
|
||||
|
||||
def _clean_chunk_info() -> None: # pyright: ignore[reportUnusedFunction]
|
||||
# Analogous to _cleanup_chunks in vespa_document_index.py.
|
||||
# TODO(andrei): Implement this. Until then, we do not enrich chunk content
|
||||
# with title, etc.
|
||||
raise NotImplementedError(
|
||||
"[ANDREI]: Clean chunk info is not implemented for OpenSearch."
|
||||
)
|
||||
|
||||
|
||||
class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
"""
|
||||
Wrapper for OpenSearch to adapt the new DocumentIndex interface with
|
||||
invocations to the old DocumentIndex interface in the hotpath.
|
||||
|
||||
The analogous class for Vespa is VespaIndex which calls to
|
||||
VespaDocumentIndex.
|
||||
|
||||
TODO(andrei): This is very dumb and purely temporary until there are no more
|
||||
references to the old interface in the hotpath.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
secondary_large_chunks_enabled: bool | None,
|
||||
multitenant: bool = False,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
# TODO(andrei): Sus. Do not plug this into production until all
|
||||
# instances where tenant ID is passed into a method call get
|
||||
# refactored to passing this data in on class init.
|
||||
tenant_state=TenantState(tenant_id="", multitenant=multitenant),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
indices: list[str],
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
raise NotImplementedError(
|
||||
"[ANDREI]: Multitenant index registration is not implemented for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
# Only handle primary index for now, ignore secondary.
|
||||
return self._real_index.verify_and_create_index_if_necessary(
|
||||
primary_embedding_dim, primary_embedding_precision
|
||||
)
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
# Convert IndexBatchParams to IndexingMetadata.
|
||||
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
|
||||
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
|
||||
old_count = index_batch_params.doc_id_to_previous_chunk_cnt[doc_id]
|
||||
new_count = index_batch_params.doc_id_to_new_chunk_cnt[doc_id]
|
||||
chunk_counts[doc_id] = IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old_count,
|
||||
new_chunk_cnt=new_count,
|
||||
)
|
||||
|
||||
indexing_metadata = IndexingMetadata(doc_id_to_chunk_cnt_diff=chunk_counts)
|
||||
|
||||
results = self._real_index.index(chunks, indexing_metadata)
|
||||
|
||||
# Convert list[DocumentInsertionRecord] to
|
||||
# set[OldDocumentInsertionRecord].
|
||||
return {
|
||||
OldDocumentInsertionRecord(
|
||||
document_id=record.document_id,
|
||||
already_existed=record.already_existed,
|
||||
)
|
||||
for record in results
|
||||
}
|
||||
|
||||
def delete_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self._real_index.delete(doc_id, chunk_count)
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> None:
|
||||
if fields is None and user_fields is None:
|
||||
raise ValueError(
|
||||
f"Bug: Tried to update document {doc_id} with no updated fields or user fields."
|
||||
)
|
||||
|
||||
# Convert VespaDocumentFields to MetadataUpdateRequest.
|
||||
update_request = MetadataUpdateRequest(
|
||||
document_ids=[doc_id],
|
||||
doc_id_to_chunk_cnt={
|
||||
doc_id: chunk_count if chunk_count is not None else -1
|
||||
},
|
||||
access=fields.access if fields else None,
|
||||
document_sets=fields.document_sets if fields else None,
|
||||
boost=fields.boost if fields else None,
|
||||
hidden=fields.hidden if fields else None,
|
||||
project_ids=(
|
||||
set(user_fields.user_projects)
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
return self._real_index.update([update_request])
|
||||
|
||||
def update(
|
||||
self,
|
||||
update_requests: list[UpdateRequest],
|
||||
*,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
raise NotImplementedError("[ANDREI]: Update is not implemented for OpenSearch.")
|
||||
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
filters: IndexFilters,
|
||||
batch_retrieval: bool = False,
|
||||
get_large_chunks: bool = False,
|
||||
) -> list[InferenceChunk]:
|
||||
section_requests = [
|
||||
DocumentSectionRequest(
|
||||
document_id=req.document_id,
|
||||
min_chunk_ind=req.min_chunk_ind,
|
||||
max_chunk_ind=req.max_chunk_ind,
|
||||
)
|
||||
for req in chunk_requests
|
||||
]
|
||||
|
||||
return self._real_index.id_based_retrieval(
|
||||
section_requests, filters, batch_retrieval
|
||||
)
|
||||
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
filters: IndexFilters,
|
||||
hybrid_alpha: float,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
ranking_profile_type: QueryExpansionType = QueryExpansionType.SEMANTIC,
|
||||
offset: int = 0,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
) -> list[InferenceChunk]:
|
||||
# Determine query type based on hybrid_alpha.
|
||||
if hybrid_alpha >= 0.8:
|
||||
query_type = QueryType.SEMANTIC
|
||||
elif hybrid_alpha <= 0.2:
|
||||
query_type = QueryType.KEYWORD
|
||||
else:
|
||||
query_type = QueryType.SEMANTIC # Default to semantic for hybrid.
|
||||
|
||||
return self._real_index.hybrid_retrieval(
|
||||
query=query,
|
||||
query_embedding=query_embedding,
|
||||
final_keywords=final_keywords,
|
||||
query_type=query_type,
|
||||
filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
offset=offset,
|
||||
)

    def admin_retrieval(
        self,
        query: str,
        filters: IndexFilters,
        num_to_retrieve: int,
        offset: int = 0,
    ) -> list[InferenceChunk]:
        raise NotImplementedError(
            "[ANDREI]: Admin retrieval is not implemented for OpenSearch."
        )

    def random_retrieval(
        self,
        filters: IndexFilters,
        num_to_retrieve: int = 100,
    ) -> list[InferenceChunk]:
        return self._real_index.random_retrieval(
            filters=filters,
            num_to_retrieve=num_to_retrieve,
            dirty=None,
        )


class OpenSearchDocumentIndex(DocumentIndex):
    """OpenSearch-specific implementation of the DocumentIndex interface.

    This class provides document indexing, retrieval, and management operations
    for an OpenSearch search engine instance. It handles the complete lifecycle
    of document chunks within a specific OpenSearch index/schema.

    Although not yet used in this way in the codebase, each kind of embedding
    used should correspond to a different instance of this class, and therefore
    a different index in OpenSearch.
    """

    def __init__(
        self,
        index_name: str,
        tenant_state: TenantState,
    ) -> None:
        self._index_name: str = index_name
        self._tenant_state: TenantState = tenant_state
        self._os_client = OpenSearchClient(index_name=self._index_name)

    def verify_and_create_index_if_necessary(
        self, embedding_dim: int, embedding_precision: EmbeddingPrecision
    ) -> None:
        expected_mappings = DocumentSchema.get_document_schema(
            embedding_dim, self._tenant_state.multitenant
        )
        if not self._os_client.index_exists():
            self._os_client.create_index(
                mappings=expected_mappings,
                settings=DocumentSchema.get_index_settings(),
            )
        if not self._os_client.validate_index(
            expected_mappings=expected_mappings,
        ):
            raise RuntimeError(
                f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
            )

        self._os_client.create_search_pipeline(
            pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
            pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
        )
        self._os_client.create_search_pipeline(
            pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
            pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
        )

    def index(
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        indexing_metadata: IndexingMetadata,
    ) -> list[DocumentInsertionRecord]:
        # Set of doc IDs.
        unique_docs_to_be_indexed: set[str] = set()
        document_indexing_results: list[DocumentInsertionRecord] = []
        for chunk in chunks:
            document_insertion_record: DocumentInsertionRecord | None = None
            onyx_document: Document = chunk.source_document
            if onyx_document.id not in unique_docs_to_be_indexed:
                # If this is the first time we see this doc in this indexing
                # operation, first delete the doc's chunks from the index. This
                # is so that there are no dangling chunks in the index, in the
                # event that the new document's content contains fewer chunks
                # than the previous content.
                # TODO(andrei): This can possibly be made more efficient by
                # checking if the chunk count has actually decreased. This
                # assumes that overlapping chunks are perfectly overwritten. If
                # we can't guarantee that then we need the code as-is.
                unique_docs_to_be_indexed.add(onyx_document.id)
                num_chunks_deleted = self.delete(
                    onyx_document.id, onyx_document.chunk_count
                )
                # If we see that chunks were deleted we assume the doc already
                # existed.
                document_insertion_record = DocumentInsertionRecord(
                    document_id=onyx_document.id,
                    already_existed=num_chunks_deleted > 0,
                )

            opensearch_document_chunk = _convert_onyx_chunk_to_opensearch_document(
                chunk
            )
            # TODO(andrei): Enrich chunk content here.
            # TODO(andrei): After our client supports batch indexing, use that
            # here.
            self._os_client.index_document(opensearch_document_chunk)

            if document_insertion_record is not None:
                # Only add records once per doc. This object is not None only if
                # we've seen this doc for the first time in this for-loop.
                document_indexing_results.append(document_insertion_record)

        return document_indexing_results

    def delete(self, document_id: str, chunk_count: int | None = None) -> int:
        """Deletes all chunks for a given document.

        TODO(andrei): Make this method require supplying source type.
        TODO(andrei): Consider implementing this method to delete on document
        chunk IDs vs querying for matching document chunks.

        Args:
            document_id: The ID of the document to delete.
            chunk_count: The number of chunks in OpenSearch for the document.
                Defaults to None.

        Raises:
            RuntimeError: Failed to delete some or all of the chunks for the
                document.

        Returns:
            The number of chunks successfully deleted.
        """
        query_body = DocumentQuery.delete_from_document_id_query(
            document_id=document_id,
            tenant_state=self._tenant_state,
        )

        return self._os_client.delete_by_query(query_body)

    def update(
        self,
        update_requests: list[MetadataUpdateRequest],
    ) -> None:
        logger.info("[ANDREI]: Updating documents...")
        # TODO(andrei): This needs to be implemented. I explicitly do not raise
        # here despite this not being implemented because indexing calls this
        # method so it is very hard to test other methods of this class if this
        # raises.

    def id_based_retrieval(
        self,
        chunk_requests: list[DocumentSectionRequest],
        filters: IndexFilters,
        # TODO(andrei): Remove this from the new interface at some point; we
        # should not be exposing this.
        batch_retrieval: bool = False,
    ) -> list[InferenceChunk]:
        """
        TODO(andrei): Consider implementing this method to retrieve on document
        chunk IDs vs querying for matching document chunks.
        """
        results: list[InferenceChunk] = []
        for chunk_request in chunk_requests:
            document_chunks: list[DocumentChunk] = []
            query_body = DocumentQuery.get_from_document_id_query(
                document_id=chunk_request.document_id,
                tenant_state=self._tenant_state,
                max_chunk_size=chunk_request.max_chunk_size,
                min_chunk_index=chunk_request.min_chunk_ind,
                max_chunk_index=chunk_request.max_chunk_ind,
            )
            document_chunks = self._os_client.search(
                body=query_body,
                search_pipeline_id=None,
            )
            inference_chunks_uncleaned = [
                _convert_opensearch_chunk_to_inference_chunk_uncleaned(document_chunk)
                for document_chunk in document_chunks
            ]
            inference_chunks = [
                _convert_inference_chunk_uncleaned_to_inference_chunk(
                    inference_chunk_uncleaned
                )
                for inference_chunk_uncleaned in inference_chunks_uncleaned
            ]
            results.extend(inference_chunks)
        # TODO(andrei): Clean chunk content here.
        return results

    def hybrid_retrieval(
        self,
        query: str,
        query_embedding: Embedding,
        final_keywords: list[str] | None,
        query_type: QueryType,
        filters: IndexFilters,
        num_to_retrieve: int,
        offset: int = 0,
    ) -> list[InferenceChunk]:
        query_body = DocumentQuery.get_hybrid_search_query(
            query_text=query,
            query_vector=query_embedding,
            num_candidates=1000,  # TODO(andrei): Magic number.
            num_hits=num_to_retrieve,
            tenant_state=self._tenant_state,
        )
        document_chunks = self._os_client.search(
            body=query_body,
            search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
        )
        # TODO(andrei): Clean chunk content here.
        inference_chunks_uncleaned = [
            _convert_opensearch_chunk_to_inference_chunk_uncleaned(document_chunk)
            for document_chunk in document_chunks
        ]
        inference_chunks = [
            _convert_inference_chunk_uncleaned_to_inference_chunk(
                inference_chunk_uncleaned
            )
            for inference_chunk_uncleaned in inference_chunks_uncleaned
        ]
        return inference_chunks

    def random_retrieval(
        self,
        filters: IndexFilters,
        num_to_retrieve: int = 100,
        dirty: bool | None = None,
    ) -> list[InferenceChunk]:
        raise NotImplementedError(
            "[ANDREI]: Random retrieval is not implemented for OpenSearch."
        )
@@ -1,327 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Self

from pydantic import BaseModel
from pydantic import field_serializer
from pydantic import model_validator

from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
from onyx.document_index.opensearch.constants import EF_SEARCH
from onyx.document_index.opensearch.constants import M


TITLE_FIELD_NAME = "title"
TITLE_VECTOR_FIELD_NAME = "title_vector"
CONTENT_FIELD_NAME = "content"
CONTENT_VECTOR_FIELD_NAME = "content_vector"
NUM_TOKENS_FIELD_NAME = "num_tokens"
SOURCE_TYPE_FIELD_NAME = "source_type"
METADATA_FIELD_NAME = "metadata"
LAST_UPDATED_FIELD_NAME = "last_updated"
CREATED_AT_FIELD_NAME = "created_at"
PUBLIC_FIELD_NAME = "public"
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
HIDDEN_FIELD_NAME = "hidden"
GLOBAL_BOOST_FIELD_NAME = "global_boost"
SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
IMAGE_FILE_NAME_FIELD_NAME = "image_file_name"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
PROJECT_IDS_FIELD_NAME = "project_ids"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
TENANT_ID_FIELD_NAME = "tenant_id"
BLURB_FIELD_NAME = "blurb"


def get_opensearch_doc_chunk_id(
    document_id: str, chunk_index: int, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE
) -> str:
    """
    Returns a unique identifier for the chunk.

    TODO(andrei): Add source type to this.
    TODO(andrei): Add tenant ID to this.
    TODO(andrei): Sanitize document_id in the event it contains characters that
    are not allowed in OpenSearch IDs.
    """
    return f"{document_id}__{max_chunk_size}__{chunk_index}"
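
# Illustrative sketch (hypothetical values, worked from the f-string above): the
# chunk ID composes the document ID, the max chunk size bucket, and the chunk
# index:
#
#     get_opensearch_doc_chunk_id("DOC-123", chunk_index=4, max_chunk_size=512)
#     -> "DOC-123__512__4"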


class DocumentChunk(BaseModel):
    """
    Represents a chunk of a document in the OpenSearch index.

    The names of these fields are based on the OpenSearch schema. Changes to the
    schema require changes here. See get_document_schema.
    """

    model_config = {"frozen": True}

    document_id: str
    chunk_index: int
    # The maximum number of tokens this chunk's content can hold. Previously
    # there was a concept of large chunks, this is a generic concept of that. We
    # can choose to have any size of chunks in the index and they should be
    # distinct from one another.
    max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE

    # Either both should be None or both should be non-None.
    title: str | None = None
    title_vector: list[float] | None = None
    content: str
    content_vector: list[float]
    # The actual number of tokens in the chunk.
    num_tokens: int

    source_type: str
    # Application logic should store these strings the format key:::value.
    metadata: list[str] | None = None
    last_updated: datetime | None = None
    created_at: datetime | None = None

    public: bool
    access_control_list: list[str] | None = None
    # Defaults to False, currently gets written during update not index.
    hidden: bool = False

    global_boost: float = 1.0

    semantic_identifier: str
    image_file_name: str | None = None
    # Contains a string representation of a dict which maps offset into the raw
    # chunk text to the link corresponding to that point.
    source_links: str | None = None
    blurb: str

    document_sets: list[str] | None = None
    project_ids: list[int] | None = None

    tenant_id: str | None = None

    @model_validator(mode="after")
    def check_num_tokens_fits_within_max_chunk_size(self) -> Self:
        if self.num_tokens > self.max_chunk_size:
            raise ValueError(
                "Bug: Num tokens must be less than or equal to max chunk size."
            )
        return self

    @model_validator(mode="after")
    def check_title_and_title_vector_are_consistent(self) -> Self:
        # title and title_vector should both either be None or not.
        if self.title is not None and self.title_vector is None:
            raise ValueError("Bug: Title vector must not be None if title is not None.")
        if self.title_vector is not None and self.title is None:
            raise ValueError("Bug: Title must not be None if title vector is not None.")
        return self

    @field_serializer("last_updated", "created_at", mode="plain")
    def serialize_datetime_fields_to_epoch_millis(
        self, value: datetime | None
    ) -> int | None:
        """
        Serializes datetime fields to milliseconds since the Unix epoch.
        """
        if value is None:
            return None
        if value.tzinfo is None:
            # astimezone will raise if value does not have a timezone set.
            value = value.replace(tzinfo=timezone.utc)
        else:
            # Does appropriate time conversion if value was set in a different
            # timezone.
            value = value.astimezone(timezone.utc)
        # timestamp returns a float in seconds so convert to millis.
        return int(value.timestamp() * 1000)
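
# Illustrative sketch (hypothetical input, worked from the serializer above): a
# timezone-aware datetime serializes to milliseconds since the Unix epoch, e.g.
#
#     datetime(2024, 1, 1, tzinfo=timezone.utc) -> 1704067200000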


class DocumentSchema:
    """
    Represents the schema and indexing strategies of the OpenSearch index.

    TODO(andrei): Implement multi-phase indexing strategies.
    """

    @staticmethod
    def get_document_schema(vector_dimension: int, multitenant: bool) -> dict[str, Any]:
        """Returns the document schema for the OpenSearch index.

        WARNING: Changes / additions to field names here require changes to the
        DocumentChunk class above.

        Notes:
            - By default all fields have indexing enabled.
            - By default almost all fields except text fields have doc_values
              enabled, enabling operations like sorting and aggregations.
            - By default all fields are nullable.
            - "type": "keyword" fields are stored as-is, used for exact matches,
              filtering, etc.
            - "type": "text" fields are OpenSearch-processed strings, used for
              full-text searches.
            - "store": True fields are stored and can be returned on their own,
              independent of the parent document.

        Args:
            vector_dimension: The dimension of vector embeddings. Must be a
                positive integer.
            multitenant: Whether the index is multitenant.

        Returns:
            A dictionary representing the document schema, to be supplied to the
            OpenSearch client. The structure of this dictionary is
            determined by OpenSearch documentation.
        """
        schema = {
            "properties": {
                TITLE_FIELD_NAME: {
                    "type": "text",
                    "fields": {
                        # Subfield accessed as title.keyword. Not indexed for
                        # values longer than 256 chars.
                        "keyword": {"type": "keyword", "ignore_above": 256}
                    },
                },
                CONTENT_FIELD_NAME: {
                    "type": "text",
                    "store": True,
                },
                TITLE_VECTOR_FIELD_NAME: {
                    "type": "knn_vector",
                    "dimension": vector_dimension,
                    "method": {
                        "name": "hnsw",
                        "space_type": "cosinesimil",
                        "engine": "lucene",
                        "parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
                    },
                },
                CONTENT_VECTOR_FIELD_NAME: {
                    "type": "knn_vector",
                    "dimension": vector_dimension,
                    "method": {
                        "name": "hnsw",
                        "space_type": "cosinesimil",
                        "engine": "lucene",
                        "parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
                    },
                },
                # See TODO in _convert_onyx_chunk_to_opensearch_document. I
                # don't want to actually add this to the schema until we know
                # for sure we need it. If we decide we don't I will remove this.
                # # Number of tokens in the chunk's content.
                # NUM_TOKENS_FIELD_NAME: {"type": "integer", "store": True},
                SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
                # Application logic should store in the format key:::value.
                METADATA_FIELD_NAME: {"type": "keyword"},
                LAST_UPDATED_FIELD_NAME: {
                    "type": "date",
                    "format": "epoch_millis",
                    # For some reason date defaults to False, even though it
                    # would make sense to sort by date.
                    "doc_values": True,
                },
                # See TODO in _convert_onyx_chunk_to_opensearch_document. I
                # don't want to actually add this to the schema until we know
                # for sure we need it. If we decide we don't I will remove this.
                # CREATED_AT_FIELD_NAME: {
                #     "type": "date",
                #     "format": "epoch_millis",
                #     # For some reason date defaults to False, even though it
                #     # would make sense to sort by date.
                #     "doc_values": True,
                # },
                # Access control fields.
                # Whether the doc is public. Could have fallen under access
                # control list but is such a broad and critical filter that it
                # is its own field.
                PUBLIC_FIELD_NAME: {"type": "boolean"},
                # Access control list for the doc, excluding public access,
                # which is covered above.
                ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
                # Whether the doc is hidden from search results. Should clobber
                # all other search filters; up to search implementations to
                # guarantee this.
                HIDDEN_FIELD_NAME: {"type": "boolean"},
                GLOBAL_BOOST_FIELD_NAME: {"type": "float"},
                # This field is only used for displaying a useful name for the
                # doc in the UI and is not used for searching. Disabling these
                # features to increase perf.
                SEMANTIC_IDENTIFIER_FIELD_NAME: {
                    "type": "keyword",
                    "index": False,
                    "doc_values": False,
                    "store": False,
                },
                # Same as above; used to display an image along with the doc.
                IMAGE_FILE_NAME_FIELD_NAME: {
                    "type": "keyword",
                    "index": False,
                    "doc_values": False,
                    "store": False,
                },
                # Same as above; used to link to the source doc.
                SOURCE_LINKS_FIELD_NAME: {
                    "type": "keyword",
                    "index": False,
                    "doc_values": False,
                    "store": False,
                },
                # Same as above; used to quickly summarize the doc in the UI.
                BLURB_FIELD_NAME: {
                    "type": "keyword",
                    "index": False,
                    "doc_values": False,
                    "store": False,
                },
                # Product-specific fields.
                DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
                PROJECT_IDS_FIELD_NAME: {"type": "integer"},
                # OpenSearch metadata fields.
                DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
                CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
                # The maximum number of tokens this chunk's content can hold.
                MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
            }
        }

        if multitenant:
            schema["properties"][TENANT_ID_FIELD_NAME] = {"type": "keyword"}

        return schema

    @staticmethod
    def get_index_settings() -> dict[str, Any]:
        """
        Standard settings for reasonable local index and search performance.
        """
        return {
            "index": {
                "number_of_shards": 1,
                "number_of_replicas": 1,
                # Required for vector search.
                "knn": True,
                "knn.algo_param.ef_search": EF_SEARCH,
            }
        }

    @staticmethod
    def get_bulk_index_settings() -> dict[str, Any]:
        """
        Optimized settings for bulk indexing: disable refresh and replicas.
        """
        return {
            "index": {
                "number_of_shards": 1,
                "number_of_replicas": 0,  # No replication during bulk load.
                # Disables auto-refresh, improves performance in pure indexing (no searching) scenarios.
                "refresh_interval": "-1",
                # Required for vector search.
                "knn": True,
                "knn.algo_param.ef_search": EF_SEARCH,
            }
        }
@@ -1,347 +0,0 @@
from typing import Any

from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME

# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
# of the weights should sum to 1.

# TODO(andrei): Turn all magic dictionaries to pydantic models.

MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
    "description": "Normalization for keyword and vector scores using min-max",
    "phase_results_processors": [
        {
            # https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
            "normalization-processor": {
                "normalization": {"technique": "min_max"},
                "combination": {
                    "technique": "arithmetic_mean",
                    "parameters": {
                        "weights": [
                            SEARCH_TITLE_VECTOR_WEIGHT,
                            SEARCH_CONTENT_VECTOR_WEIGHT,
                            SEARCH_TITLE_KEYWORD_WEIGHT,
                            SEARCH_CONTENT_KEYWORD_WEIGHT,
                            SEARCH_CONTENT_PHRASE_WEIGHT,
                        ]
                    },
                },
            }
        }
    ],
}

ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
ZSCORE_NORMALIZATION_PIPELINE_CONFIG = {
    "description": "Normalization for keyword and vector scores using z-score",
    "phase_results_processors": [
        {
            # https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
            "normalization-processor": {
                "normalization": {"technique": "z_score"},
                "combination": {
                    "technique": "arithmetic_mean",
                    "parameters": {
                        "weights": [
                            SEARCH_TITLE_VECTOR_WEIGHT,
                            SEARCH_CONTENT_VECTOR_WEIGHT,
                            SEARCH_TITLE_KEYWORD_WEIGHT,
                            SEARCH_CONTENT_KEYWORD_WEIGHT,
                            SEARCH_CONTENT_PHRASE_WEIGHT,
                        ]
                    },
                },
            }
        }
    ],
}

assert (
    sum(
        [
            SEARCH_TITLE_VECTOR_WEIGHT,
            SEARCH_CONTENT_VECTOR_WEIGHT,
            SEARCH_TITLE_KEYWORD_WEIGHT,
            SEARCH_CONTENT_KEYWORD_WEIGHT,
            SEARCH_CONTENT_PHRASE_WEIGHT,
        ]
    )
    == 1.0
)
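
# Illustrative worked example (hypothetical weights and scores, not the imported
# constants above): with weights [0.2, 0.3, 0.1, 0.2, 0.2] and min-max-normalized
# clause scores [1.0, 0.5, 0.0, 0.8, 0.25], the arithmetic_mean combination gives
#
#     0.2 * 1.0 + 0.3 * 0.5 + 0.1 * 0.0 + 0.2 * 0.8 + 0.2 * 0.25 = 0.56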


# By default OpenSearch will only return a maximum of this many results in a
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000


class DocumentQuery:
    """
    TODO(andrei): Implement multi-phase search strategies.
    TODO(andrei): Implement document boost.
    TODO(andrei): Implement document age.
    """

    @staticmethod
    def get_from_document_id_query(
        document_id: str,
        tenant_state: TenantState,
        max_chunk_size: int,
        min_chunk_index: int | None,
        max_chunk_index: int | None,
        get_full_document: bool = True,
    ) -> dict[str, Any]:
        """
        Returns a final search query which gets chunks from a given document ID.

        This query can be directly supplied to the OpenSearch client.

        TODO(andrei): Currently capped at 10k results. Implement scroll/point in
        time for results so that we can return arbitrarily-many IDs.

        Args:
            document_id: Onyx document ID. Notably not an OpenSearch document
                ID, which points to what Onyx would refer to as a chunk.
            tenant_state: Tenant state containing the tenant ID.
            max_chunk_size: Document chunks are categorized by the maximum
                number of tokens they can hold. This parameter specifies the
                maximum size category of document chunks to retrieve.
            min_chunk_index: The minimum chunk index to retrieve, inclusive. If
                None, no minimum chunk index will be applied.
            max_chunk_index: The maximum chunk index to retrieve, inclusive. If
                None, no maximum chunk index will be applied.
            get_full_document: Whether to get the full document body. If False,
                OpenSearch will only return the matching document chunk IDs plus
                metadata; the source data will be omitted from the response. Use
                this for performance optimization if OpenSearch IDs are
                sufficient. Defaults to True.

        Returns:
            A dictionary representing the final ID search query.
        """
        filter_clauses: list[dict[str, Any]] = [
            {"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
        ]

        if tenant_state.tenant_id is not None:
            # TODO(andrei): Fix tenant stuff.
            filter_clauses.append(
                {"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
            )

        if min_chunk_index is not None or max_chunk_index is not None:
            range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
            if min_chunk_index is not None:
                range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
            if max_chunk_index is not None:
                range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
            filter_clauses.append(range_clause)

        filter_clauses.append(
            {"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
        )

        final_get_ids_query: dict[str, Any] = {
            "query": {"bool": {"filter": filter_clauses}},
            # We include this to make sure OpenSearch does not revert to
            # returning some number of results less than the index max allowed
            # return size.
            "size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
            "_source": get_full_document,
        }

        return final_get_ids_query
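
    # Illustrative sketch (hypothetical arguments, shape worked from the code
    # above): for a single-tenant deployment,
    # get_from_document_id_query("DOC-123", tenant_state, max_chunk_size=512,
    # min_chunk_index=0, max_chunk_index=3) builds roughly:
    #
    #     {
    #         "query": {"bool": {"filter": [
    #             {"term": {"document_id": {"value": "DOC-123"}}},
    #             {"range": {"chunk_index": {"gte": 0, "lte": 3}}},
    #             {"term": {"max_chunk_size": {"value": 512}}},
    #         ]}},
    #         "size": 10_000,
    #         "_source": True,
    #     }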

    @staticmethod
    def delete_from_document_id_query(
        document_id: str,
        tenant_state: TenantState,
    ) -> dict[str, Any]:
        """
        Returns a final search query which deletes chunks from a given document
        ID.

        This query can be directly supplied to the OpenSearch client.

        Intended to be supplied to the OpenSearch client's delete_by_query
        method.

        TODO(andrei): There is no limit to the number of document chunks that
        can be deleted by this query. This could get expensive. Consider
        implementing batching.

        Args:
            document_id: Onyx document ID. Notably not an OpenSearch document
                ID, which points to what Onyx would refer to as a chunk.
            tenant_state: Tenant state containing the tenant ID.

        Returns:
            A dictionary representing the final delete query.
        """
        filter_clauses: list[dict[str, Any]] = [
            {"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
        ]

        if tenant_state.tenant_id is not None:
            filter_clauses.append(
                {"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
            )

        final_delete_query: dict[str, Any] = {
            "query": {"bool": {"filter": filter_clauses}},
        }

        return final_delete_query

    @staticmethod
    def get_hybrid_search_query(
        query_text: str,
        query_vector: list[float],
        num_candidates: int,
        num_hits: int,
        tenant_state: TenantState,
    ) -> dict[str, Any]:
        """Returns a final hybrid search query.

        This query can be directly supplied to the OpenSearch client.

        Args:
            query_text: The text to query for.
            query_vector: The vector embedding of the text to query for.
            num_candidates: The number of candidates to consider for vector
                similarity search. Generally more candidates improves search
                quality at the cost of performance.
            num_hits: The final number of hits to return.
            tenant_state: Tenant state containing the tenant ID.

        Returns:
            A dictionary representing the final hybrid search query.
        """
        if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
            raise ValueError(
                f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
                f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
            )

        hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
            query_text, query_vector, num_candidates
        )
        hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)

        hybrid_search_query: dict[str, Any] = {
            "bool": {
                "must": [
                    {
                        "hybrid": {
                            "queries": hybrid_search_subqueries,
                        }
                    }
                ],
                "filter": hybrid_search_filters,
            }
        }

        final_hybrid_search_body: dict[str, Any] = {
            "query": hybrid_search_query,
            "size": num_hits,
        }
        return final_hybrid_search_body

    @staticmethod
    def _get_hybrid_search_subqueries(
        query_text: str, query_vector: list[float], num_candidates: int
    ) -> list[dict[str, Any]]:
        """Returns subqueries for hybrid search.

        Each of these subqueries are the "hybrid" component of this search. We
        search on various things and combine results.

        The return of this function is not sufficient to be directly supplied to
        the OpenSearch client. See get_hybrid_search_query.

        Matches:
            - Title vector
            - Content vector
            - Title keyword
            - Content keyword
            - Content phrase

        Normalization is not performed here.
        The weights of each of these subqueries should be configured in a search
        pipeline.

        NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
        in a single hybrid query.

        Args:
            query_text: The text of the query to search for.
            query_vector: The vector embedding of the query to search for.
            num_candidates: The number of candidates to consider for vector
                similarity search.
        """
        hybrid_search_queries: list[dict[str, Any]] = [
            {
                "knn": {
                    TITLE_VECTOR_FIELD_NAME: {
                        "vector": query_vector,
                        "k": num_candidates,
                    }
                }
            },
            {
                "knn": {
                    CONTENT_VECTOR_FIELD_NAME: {
                        "vector": query_vector,
                        "k": num_candidates,
                    }
                }
            },
            {
                "multi_match": {
                    "query": query_text,
                    "fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
                    "type": "best_fields",
                }
            },
            {"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
            {"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
        ]
        return hybrid_search_queries
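
    # Illustrative note (worked from the list above, hypothetical call values):
    # the five clauses are emitted in the order title-vector knn, content-vector
    # knn, title multi_match, content match, content match_phrase, which is the
    # same order the normalization pipelines list their weights, e.g.
    #
    #     clauses = DocumentQuery._get_hybrid_search_subqueries("onyx", [0.1] * 768, 1000)
    #     len(clauses) == 5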

    @staticmethod
    def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
        """Returns filters for hybrid search.

        For now only fetches public and not hidden documents.

        The return of this function is not sufficient to be directly supplied to
        the OpenSearch client. See get_hybrid_search_query.

        TODO(andrei): Add ACL filters and stuff.
        """
        hybrid_search_filters: list[dict[str, Any]] = [
            {"term": {PUBLIC_FIELD_NAME: {"value": True}}},
            {"term": {HIDDEN_FIELD_NAME: {"value": False}}},
        ]
        if tenant_state.tenant_id is not None:
            hybrid_search_filters.append(
                {"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
            )
        return hybrid_search_filters
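
    # Illustrative sketch (hypothetical tenant ID, shape worked from the code
    # above): for a multi-tenant deployment with tenant_id "tenant_a" the filters
    # are roughly
    #
    #     [
    #         {"term": {"public": {"value": True}}},
    #         {"term": {"hidden": {"value": False}}},
    #         {"term": {"tenant_id": {"value": "tenant_a"}}},
    #     ]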