Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-19 08:45:47 +00:00)

Compare commits: dump-scrip ... overlay-fi

185 Commits
| SHA1 |
|---|
| bcc59a476b |
| 75cfa49504 |
| 6ec0b09139 |
| 53691fc95a |
| 3400e2a14d |
| d8cc1f7a2c |
| 2098e910dd |
| e5491d6f79 |
| a8934a083a |
| 80e9507e01 |
| 60d3be5fe2 |
| b481cc36d0 |
| 65c5da8912 |
| 0a0366e6ca |
| 84a623e884 |
| 6b91607b17 |
| 82fb737ad9 |
| eed49e699e |
| 3cc7afd334 |
| bcbfd28234 |
| faa47d9691 |
| 6649561bf3 |
| 026cda0468 |
| 64297e5996 |
| c517137c0a |
| cbfbe0bbbe |
| 13ca4c6650 |
| e8d9e36d62 |
| 77e4f3c574 |
| 2bdc06201a |
| 077ba9624c |
| 81eb1a1c7c |
| 1a16fef783 |
| 027692d5eb |
| 3a889f7069 |
| 20d67bd956 |
| 8d6b6accaf |
| ed76b4eb55 |
| 7613c100d1 |
| c52d3412de |
| 96b6162b52 |
| 502ed8909b |
| 8de75dd033 |
| 74e3668e38 |
| 2475a9ef92 |
| 690f54c441 |
| 71bb0c029e |
| ccf890a129 |
| a7bfdebddf |
| 6fc5ca12a3 |
| 8298452522 |
| 2559327636 |
| ef185ce2c8 |
| a04fee5cbd |
| e507378244 |
| e6be3f85b2 |
| cc96e303ce |
| e0fcb1f860 |
| f5442c431d |
| 652e5848e5 |
| 3fa1896316 |
| f855ecab11 |
| fd26176e7d |
| 8986f67779 |
| 42f2d4aca5 |
| 7116d24a8c |
| 7f4593be32 |
| f47e25e693 |
| 877184ae97 |
| 54961ec8ef |
| e797971ce5 |
| 566cca70d8 |
| be2d0e2b5d |
| 692f937ca4 |
| 11de1ceb65 |
| 19993b4679 |
| 9063827782 |
| 0cc6fa49d7 |
| 3f3508b668 |
| 1c3a88daf8 |
| 92f30bbad9 |
| 4abf43d85b |
| b08f9adb23 |
| 7a915833bb |
| 9698b700e6 |
| fd944acc5b |
| a1309257f5 |
| 6266dc816d |
| 83c011a9e4 |
| 8d1ac81d09 |
| d8cd4c9928 |
| 5caa4fdaa0 |
| f22f33564b |
| f86d282a47 |
| ece1edb80f |
| c9c17e19f3 |
| 40e834e0b8 |
| 45bd82d031 |
| 27c1619c3d |
| 8cfeb85c43 |
| 491b550ebc |
| 1a94dfd113 |
| bcd9d7ae41 |
| 98b4353632 |
| f071b280d4 |
| f7ebaa42fc |
| 11737c2069 |
| 1712253e5f |
| de8f292fce |
| bbe5058131 |
| 45fc5e3c97 |
| 5c976815cc |
| 3ea4b6e6cc |
| 7b75c0049b |
| 04bdce55f4 |
| 2446b1898e |
| 6f22a2f656 |
| e307a84863 |
| 2dd27f25cb |
| e402c0e3b4 |
| 2721c8582a |
| 43c8b7a712 |
| f473b85acd |
| 02cd84c39a |
| 46d17d6c64 |
| 10ad536491 |
| ccabc1a7a7 |
| 8e262e4da8 |
| 79dea9d901 |
| 2f650bbef8 |
| 021e67ca71 |
| 87ae024280 |
| 5092429557 |
| dc691199f5 |
| 1662c391f0 |
| 08aefbc115 |
| fb6342daa9 |
| 4e7adcc9ee |
| aa4b3d8a24 |
| f3bc459b6e |
| 87cab60b01 |
| 08ab73caf8 |
| 675761c81e |
| 18e15c6da6 |
| e1f77e2e17 |
| 4ef388b2dc |
| 031485232b |
| c0debefaf6 |
| bbebe5f201 |
| ac9cb22fee |
| 5e281ce2e6 |
| 9ea5b7a424 |
| e0b83fad4c |
| 7191b9010d |
| fb3428ed37 |
| 444ad297da |
| f46df421a7 |
| 98a2e12090 |
| 36bfa8645e |
| 56e71d7f6c |
| e0d172615b |
| bde52b13d4 |
| b273d91512 |
| 1fbe76a607 |
| 6ee7316130 |
| 51802f46bb |
| d430444424 |
| 17fff6c805 |
| a33f6e8416 |
| d157649069 |
| 77bbb9f7a7 |
| 996b5177d9 |
| ab9a3ba970 |
| 87c1f0ab10 |
| dcea1d88e5 |
| cc481e20d3 |
| 4d141a8f68 |
| cb32c81d1b |
| 64f327fdef |
| 902d6112c3 |
| f71e3b9151 |
| dd7e1520c5 |
| 97553de299 |
| c80ab8b200 |
| 85c4ddce39 |
192 .github/workflows/deployment.yml (vendored)
@@ -6,8 +6,9 @@ on:
|
||||
- "*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
@@ -20,6 +21,7 @@ jobs:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
outputs:
|
||||
build-desktop: ${{ steps.check.outputs.build-desktop }}
|
||||
build-web: ${{ steps.check.outputs.build-web }}
|
||||
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
|
||||
build-backend: ${{ steps.check.outputs.build-backend }}
|
||||
@@ -30,6 +32,7 @@ jobs:
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
@@ -38,6 +41,7 @@ jobs:
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
IS_CLOUD=false
|
||||
BUILD_DESKTOP=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
@@ -47,13 +51,6 @@ jobs:
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
fi
|
||||
|
||||
# Version checks (for web - any stable version)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
IS_STABLE=true
|
||||
@@ -62,6 +59,17 @@ jobs:
|
||||
IS_BETA=true
|
||||
fi
|
||||
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Skip desktop builds on beta tags
|
||||
if [[ "$IS_BETA" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Version checks (for backend/model-server - stable version excluding cloud tags)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
@@ -70,7 +78,9 @@ jobs:
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
{
|
||||
echo "build-desktop=$BUILD_DESKTOP"
|
||||
echo "build-web=$BUILD_WEB"
|
||||
echo "build-web-cloud=$BUILD_WEB_CLOUD"
|
||||
echo "build-backend=$BUILD_BACKEND"
|
||||
@@ -81,6 +91,7 @@ jobs:
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
check-version-tag:
|
||||
@@ -95,7 +106,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7.1.5
|
||||
with:
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
enable-cache: false
|
||||
@@ -124,6 +135,136 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: 'macos-latest' # Build a universal image for macOS.
|
||||
args: '--target universal-apple-darwin'
|
||||
- platform: 'ubuntu-24.04'
|
||||
args: '--bundles deb,rpm'
|
||||
- platform: 'ubuntu-24.04-arm' # Only available in public repos.
|
||||
args: '--bundles deb,rpm'
|
||||
- platform: 'windows-latest'
|
||||
args: ''
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
build-essential \
|
||||
libglib2.0-dev \
|
||||
libgirepository1.0-dev \
|
||||
libgtk-3-dev \
|
||||
libjavascriptcoregtk-4.1-dev \
|
||||
libwebkit2gtk-4.1-dev \
|
||||
libayatana-appindicator3-dev \
|
||||
gobject-introspection \
|
||||
pkg-config \
|
||||
curl \
|
||||
xdg-utils
|
||||
|
||||
- name: setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
|
||||
with:
|
||||
node-version: 24
|
||||
package-manager-cache: false
|
||||
|
||||
- name: install Rust stable
|
||||
uses: dtolnay/rust-toolchain@6d9817901c499d6b02debbb57edb38d33daa680b # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
# Those targets are only used on macos runners so it's in an `if` to slightly speed up windows and linux builds.
|
||||
targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }}
|
||||
|
||||
- name: install frontend dependencies
|
||||
working-directory: ./desktop
|
||||
run: npm install
|
||||
|
||||
- name: Inject version (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
working-directory: ./desktop
|
||||
env:
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
if [ "${EVENT_NAME}" == "workflow_dispatch" ]; then
|
||||
VERSION="0.0.0-dev+${SHORT_SHA}"
|
||||
else
|
||||
VERSION="${GITHUB_REF_NAME#v}"
|
||||
fi
|
||||
echo "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
sed "s/^version = .*/version = \"$VERSION\"/" src-tauri/Cargo.toml > src-tauri/Cargo.toml.tmp
|
||||
mv src-tauri/Cargo.toml.tmp src-tauri/Cargo.toml
|
||||
|
||||
# Update tauri.conf.json
|
||||
jq --arg v "$VERSION" '.version = $v' src-tauri/tauri.conf.json > src-tauri/tauri.conf.json.tmp
|
||||
mv src-tauri/tauri.conf.json.tmp src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
jq --arg v "$VERSION" '.version = $v' package.json > package.json.tmp
|
||||
mv package.json.tmp package.json
|
||||
|
||||
echo "Versions set to: $VERSION"
|
||||
|
||||
- name: Inject version (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
working-directory: ./desktop
|
||||
shell: pwsh
|
||||
run: |
|
||||
# Windows MSI requires numeric-only build metadata, so we skip the SHA suffix
|
||||
if ("${{ github.event_name }}" -eq "workflow_dispatch") {
|
||||
$VERSION = "0.0.0"
|
||||
} else {
|
||||
# Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility
|
||||
$VERSION = "$env:GITHUB_REF_NAME" -replace '^v', '' -replace '-.*$', ''
|
||||
}
|
||||
Write-Host "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
$cargo = Get-Content src-tauri/Cargo.toml -Raw
|
||||
$cargo = $cargo -replace '(?m)^version = .*', "version = `"$VERSION`""
|
||||
Set-Content src-tauri/Cargo.toml $cargo -NoNewline
|
||||
|
||||
# Update tauri.conf.json
|
||||
$json = Get-Content src-tauri/tauri.conf.json | ConvertFrom-Json
|
||||
$json.version = $VERSION
|
||||
$json | ConvertTo-Json -Depth 100 | Set-Content src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
$pkg = Get-Content package.json | ConvertFrom-Json
|
||||
$pkg.version = $VERSION
|
||||
$pkg | ConvertTo-Json -Depth 100 | Set-Content package.json
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tagName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: 'See the assets to download this version and install.'
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-web == 'true'
|
||||
@@ -147,7 +288,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -205,7 +346,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -267,7 +408,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -313,7 +454,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -379,7 +520,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -449,7 +590,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -492,7 +633,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -549,7 +690,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -610,7 +751,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -657,7 +798,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -721,7 +862,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -788,7 +929,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -980,6 +1121,7 @@ jobs:
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs:
|
||||
- build-desktop
|
||||
- build-web-amd64
|
||||
- build-web-arm64
|
||||
- merge-web
|
||||
@@ -992,7 +1134,7 @@ jobs:
|
||||
- build-model-server-amd64
|
||||
- build-model-server-arm64
|
||||
- merge-model-server
|
||||
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
@@ -1007,6 +1149,9 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
FAILED_JOBS=""
|
||||
if [ "${NEEDS_BUILD_DESKTOP_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-desktop\\n"
|
||||
fi
|
||||
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
|
||||
fi
|
||||
@@ -1047,6 +1192,7 @@ jobs:
|
||||
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
|
||||
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
|
||||
env:
|
||||
NEEDS_BUILD_DESKTOP_RESULT: ${{ needs.build-desktop.result }}
|
||||
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
|
||||
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
|
||||
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
|
||||
|
||||
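The shell logic in `determine-builds` above is compact but easy to misread, so here is a minimal Python sketch of the same tag-classification rules visible in the hunks (stable `vX.Y.Z` tags, `cloud` tags building the cloud web image, and beta tags skipping desktop builds). This is not the workflow itself; the beta-tag regex and the example tags are assumptions for illustration, since the actual beta check falls outside the visible hunks.

```python
import re
from dataclasses import dataclass

STABLE_RE = re.compile(r"^v\d+\.\d+\.\d+$")
# The beta-tag pattern is not shown in the hunks above; this form is assumed.
BETA_RE = re.compile(r"^v\d+\.\d+\.\d+-beta\.\d+$")


@dataclass
class BuildPlan:
    build_desktop: bool = False
    build_web: bool = False
    build_web_cloud: bool = False
    build_backend: bool = True  # always built, per the workflow defaults
    is_stable: bool = False
    is_beta: bool = False
    is_stable_standalone: bool = False
    sanitized_tag: str = ""


def classify_tag(tag: str) -> BuildPlan:
    # Slashes are replaced for Docker tag compatibility, as in the workflow.
    plan = BuildPlan(sanitized_tag=tag.replace("/", "-"))
    plan.is_stable = bool(STABLE_RE.match(tag))
    plan.is_beta = bool(BETA_RE.match(tag))

    if "cloud" in tag:
        plan.build_web_cloud = True
    else:
        plan.build_web = True
        # Desktop builds are skipped on beta tags.
        plan.build_desktop = not plan.is_beta

    # Backend/model-server "standalone" releases: stable version, non-cloud tag.
    plan.is_stable_standalone = plan.is_stable and "cloud" not in tag
    return plan


print(classify_tag("v1.2.3"))        # stable standalone release
print(classify_tag("v1.2.3-cloud"))  # cloud release: web-cloud image only
```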
10 .github/workflows/pr-integration-tests.yml (vendored)
@@ -33,6 +33,11 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
|
||||
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
|
||||
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
|
||||
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
@@ -399,6 +404,11 @@ jobs:
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
|
||||
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
|
||||
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
|
||||
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
|
||||
7 .github/workflows/pr-jest-tests.yml (vendored)
@@ -4,7 +4,14 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
7 .github/workflows/pr-playwright-tests.yml (vendored)
@@ -4,7 +4,14 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@@ -8,30 +8,66 @@ repos:
|
||||
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
args: ["--active", "--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export dev.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export ee.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export model_server.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
@@ -40,68 +76,73 @@ repos:
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
files: ^.github/
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# this is a fork which keeps compatibility with black
|
||||
- repo: https://github.com/wimglenn/reorder-python-imports-black
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
- id: reorder-python-imports
|
||||
args: ["--py311-plus", "--application-directories=backend/"]
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
args:
|
||||
[
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
"--in-place",
|
||||
"--recursive",
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -112,9 +153,13 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
|
||||
# This is a preview package - if it breaks:
|
||||
# 1. Try updating: cd web && npm update @typescript/native-preview
|
||||
# 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below
|
||||
- id: typescript-check
|
||||
name: TypeScript type check
|
||||
entry: bash -c 'cd web && npm run types:check'
|
||||
entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json'
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: ^web/.*\.(ts|tsx)$
|
||||
|
||||
@@ -161,7 +161,7 @@ You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:

```bash
docker compose up -d index relational_db cache minio
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```

(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
@@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add last refreshed at mcp server
|
||||
|
||||
Revision ID: 2a391f840e85
|
||||
Revises: 4cebcbc9b2ae
|
||||
Create Date: 2025-12-06 15:19:59.766066
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2a391f840e85"
|
||||
down_revision = "4cebcbc9b2ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("mcp_server", "last_refreshed_at")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add tab_index to tool_call
|
||||
|
||||
Revision ID: 4cebcbc9b2ae
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2025-12-16
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4cebcbc9b2ae"
|
||||
down_revision = "a1b2c3d4e5f6"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool_call", "tab_index")
|
||||
@@ -42,13 +42,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
@@ -63,13 +63,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
|
||||
49 backend/alembic/versions/a1b2c3d4e5f6_add_license_table.py (Normal file)
@@ -0,0 +1,49 @@
|
||||
"""add license table
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: a01bf2971c5d
|
||||
Create Date: 2025-12-04 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f6"
|
||||
down_revision = "a01bf2971c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"license",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("license_data", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Singleton pattern - only ever one row in this table
|
||||
op.create_index(
|
||||
"idx_license_singleton",
|
||||
"license",
|
||||
[sa.text("(true)")],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_license_singleton", table_name="license")
|
||||
op.drop_table("license")
|
||||
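The `idx_license_singleton` index in the migration above enforces the one-row invariant at the database level: a unique index over the constant expression `(true)` gives every row the same index key, so a second INSERT fails. A minimal sketch of the same idea in plain SQLAlchemy follows; the table name, column set, and connection string here are illustrative assumptions, not taken from the migration.

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Illustrative stand-in for the real "license" table created by the migration.
license_demo = sa.Table(
    "license_demo",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("license_data", sa.Text, nullable=False),
    # Unique index on a constant expression: all rows map to the same key,
    # so at most one row can ever exist (PostgreSQL expression index).
    sa.Index("idx_license_demo_singleton", sa.text("(true)"), unique=True),
)

# Assumed DSN; requires a reachable PostgreSQL instance.
engine = sa.create_engine("postgresql+psycopg2://localhost/demo")

with engine.begin() as conn:
    metadata.create_all(conn)
    conn.execute(license_demo.insert().values(license_data="blob-1"))  # ok
    # A second insert would violate the unique index and raise IntegrityError:
    # conn.execute(license_demo.insert().values(license_data="blob-2"))
```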
@@ -0,0 +1,27 @@
|
||||
"""Remove fast_default_model_name from llm_provider
|
||||
|
||||
Revision ID: a2b3c4d5e6f7
|
||||
Revises: 2a391f840e85
|
||||
Create Date: 2024-12-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a2b3c4d5e6f7"
|
||||
down_revision = "2a391f840e85"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("llm_provider", "fast_default_model_name")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("fast_default_model_name", sa.String(), nullable=True),
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Drop milestone table
|
||||
|
||||
Revision ID: b8c9d0e1f2a3
|
||||
Revises: a2b3c4d5e6f7
|
||||
Create Date: 2025-12-18
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b8c9d0e1f2a3"
|
||||
down_revision = "a2b3c4d5e6f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""add_deep_research_tool
|
||||
|
||||
Revision ID: c1d2e3f4a5b6
|
||||
Revises: b8c9d0e1f2a3
|
||||
Create Date: 2025-12-18 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c1d2e3f4a5b6"
|
||||
down_revision = "b8c9d0e1f2a3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
|
||||
"""
|
||||
),
|
||||
DEEP_RESEARCH_TOOL,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
lobj_id = cast(int, file_record.lobj_oid)
|
||||
file_metadata = cast(Any, file_record.file_metadata)
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
else:
|
||||
# Convert other types to dict if possible, otherwise None
|
||||
try:
|
||||
file_metadata = dict(file_record.file_metadata) # type: ignore
|
||||
file_metadata = dict(file_record.file_metadata)
|
||||
except (TypeError, ValueError):
|
||||
file_metadata = None
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import ( # type: ignore[import-untyped]
|
||||
from onyx.db.enums import (
|
||||
MCPTransport,
|
||||
MCPAuthenticationType,
|
||||
MCPAuthenticationPerformer,
|
||||
|
||||
@@ -82,9 +82,9 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -118,6 +118,6 @@ def fetch_document_sets(
|
||||
.all()
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs))
|
||||
|
||||
return document_set_with_cc_pairs
|
||||
|
||||
278 backend/ee/onyx/db/license.py (Normal file)
@@ -0,0 +1,278 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_license(db_session: Session) -> License | None:
|
||||
"""
|
||||
Get the current license (singleton pattern - only one row).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
License object if exists, None otherwise
|
||||
"""
|
||||
return db_session.execute(select(License)).scalars().first()
|
||||
|
||||
|
||||
def upsert_license(db_session: Session, license_data: str) -> License:
|
||||
"""
|
||||
Insert or update the license (singleton pattern).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
license_data: Base64-encoded signed license blob
|
||||
|
||||
Returns:
|
||||
The created or updated License object
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
|
||||
if existing:
|
||||
existing.license_data = license_data
|
||||
db_session.commit()
|
||||
db_session.refresh(existing)
|
||||
logger.info("License updated")
|
||||
return existing
|
||||
|
||||
new_license = License(license_data=license_data)
|
||||
db_session.add(new_license)
|
||||
db_session.commit()
|
||||
db_session.refresh(new_license)
|
||||
logger.info("License created")
|
||||
return new_license
|
||||
|
||||
|
||||
def delete_license(db_session: Session) -> bool:
|
||||
"""
|
||||
Delete the current license.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if no license existed
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
logger.info("License deleted")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Seat Counting
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis Cache Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
def update_license_cache(
|
||||
payload: LicensePayload,
|
||||
source: LicenseSource | None = None,
|
||||
grace_period_end: datetime | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the Redis cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
2. Caching avoids repeated DB + crypto verification on every request
|
||||
3. Status enforcement happens at the feature level, not here
|
||||
|
||||
Args:
|
||||
payload: Verified license payload
|
||||
source: How the license was obtained
|
||||
grace_period_end: Optional grace period end time
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
The cached LicenseMetadata
|
||||
"""
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id=payload.tenant_id,
|
||||
organization_name=payload.organization_name,
|
||||
seats=payload.seats,
|
||||
used_seats=used_seats,
|
||||
plan_type=payload.plan_type,
|
||||
issued_at=payload.issued_at,
|
||||
expires_at=payload.expires_at,
|
||||
grace_period_end=grace_period_end,
|
||||
status=status,
|
||||
source=source,
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
redis_client.setex(
|
||||
LICENSE_METADATA_KEY,
|
||||
LICENSE_CACHE_TTL_SECONDS,
|
||||
metadata.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
|
||||
return metadata
|
||||
|
||||
|
||||
def refresh_license_cache(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Refresh the license cache from the database.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
|
||||
license_record = get_license(db_session)
|
||||
if not license_record:
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license during cache refresh: {e}")
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
|
||||
def get_license_metadata(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata, using cache if available.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
# Try cache first
|
||||
cached = get_cached_license_metadata(tenant_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
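To make the read path of this module concrete, a short usage sketch follows. The imports mirror the functions defined above and the session helper already used inside `get_used_seats`; the surrounding script and what it prints are illustrative only.

```python
from ee.onyx.db.license import get_license_metadata
from onyx.db.engine.sql_engine import get_session_with_current_tenant

with get_session_with_current_tenant() as db_session:
    # Cache-first read: Redis metadata if present, otherwise verify the
    # license in the DB and repopulate the cache.
    metadata = get_license_metadata(db_session)
    if metadata is None:
        print("no license installed")
    else:
        remaining = metadata.seats - metadata.used_seats
        print(f"{metadata.status.value}: {remaining} seats remaining")
```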
@@ -14,6 +14,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
@@ -139,6 +140,8 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
|
||||
include_router_with_global_prefix_prepended(application, usage_export_router)
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
246 backend/ee/onyx/server/license/api.py (Normal file)
@@ -0,0 +1,246 @@
|
||||
"""License API endpoints."""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseResponse
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/license")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_license_status(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""Get current license status and seat usage."""
|
||||
metadata = get_license_metadata(db_session)
|
||||
|
||||
if not metadata:
|
||||
return LicenseStatusResponse(has_license=False)
|
||||
|
||||
    return LicenseStatusResponse(
        has_license=True,
        seats=metadata.seats,
        used_seats=metadata.used_seats,
        plan_type=metadata.plan_type,
        issued_at=metadata.issued_at,
        expires_at=metadata.expires_at,
        grace_period_end=metadata.grace_period_end,
        status=metadata.status,
        source=metadata.source,
    )


@router.get("/seats")
async def get_seat_usage(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
    """Get detailed seat usage information."""
    metadata = get_license_metadata(db_session)

    if not metadata:
        return SeatUsageResponse(
            total_seats=0,
            used_seats=0,
            available_seats=0,
        )

    return SeatUsageResponse(
        total_seats=metadata.seats,
        used_seats=metadata.used_seats,
        available_seats=max(0, metadata.seats - metadata.used_seats),
    )


@router.post("/fetch")
async def fetch_license(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseResponse:
    """
    Fetch license from control plane.
    Used after Stripe checkout completion to retrieve the new license.
    """
    tenant_id = get_current_tenant_id()

    try:
        token = generate_data_plane_token()
    except ValueError as e:
        logger.error(f"Failed to generate data plane token: {e}")
        raise HTTPException(
            status_code=500, detail="Authentication configuration error"
        )

    try:
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()

        data = response.json()
        if not isinstance(data, dict) or "license" not in data:
            raise HTTPException(
                status_code=502, detail="Invalid response from control plane"
            )

        license_data = data["license"]
        if not license_data:
            raise HTTPException(status_code=404, detail="No license found")

        # Verify signature before persisting
        payload = verify_license_signature(license_data)

        # Verify the fetched license is for this tenant
        if payload.tenant_id != tenant_id:
            logger.error(
                f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
            )
            raise HTTPException(
                status_code=400,
                detail="License tenant ID mismatch - control plane returned wrong license",
            )

        # Persist to DB and update cache atomically
        upsert_license(db_session, license_data)
        try:
            update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
        except Exception as cache_error:
            # Log but don't fail - DB is source of truth, cache will refresh on next read
            logger.warning(f"Failed to update license cache: {cache_error}")

        return LicenseResponse(success=True, license=payload)

    except requests.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 502
        logger.error(f"Control plane returned error: {status_code}")
        raise HTTPException(
            status_code=status_code,
            detail="Failed to fetch license from control plane",
        )
    except ValueError as e:
        logger.error(f"License verification failed: {type(e).__name__}")
        raise HTTPException(status_code=400, detail=str(e))
    except requests.RequestException:
        logger.exception("Failed to fetch license from control plane")
        raise HTTPException(
            status_code=502, detail="Failed to connect to control plane"
        )


@router.post("/upload")
async def upload_license(
    license_file: UploadFile = File(...),
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
    """
    Upload a license file manually.
    Used for air-gapped deployments where control plane is not accessible.
    """
    try:
        content = await license_file.read()
        license_data = content.decode("utf-8").strip()
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid license file format")

    try:
        payload = verify_license_signature(license_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    tenant_id = get_current_tenant_id()
    if payload.tenant_id != tenant_id:
        raise HTTPException(
            status_code=400,
            detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
        )

    # Persist to DB and update cache
    upsert_license(db_session, license_data)
    try:
        update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
    except Exception as cache_error:
        # Log but don't fail - DB is source of truth, cache will refresh on next read
        logger.warning(f"Failed to update license cache: {cache_error}")

    return LicenseUploadResponse(
        success=True,
        message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
    )


@router.post("/refresh")
async def refresh_license_cache_endpoint(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
    """
    Force refresh the license cache from the database.
    Useful after manual database changes or to verify license validity.
    """
    metadata = refresh_license_cache(db_session)

    if not metadata:
        return LicenseStatusResponse(has_license=False)

    return LicenseStatusResponse(
        has_license=True,
        seats=metadata.seats,
        used_seats=metadata.used_seats,
        plan_type=metadata.plan_type,
        issued_at=metadata.issued_at,
        expires_at=metadata.expires_at,
        grace_period_end=metadata.grace_period_end,
        status=metadata.status,
        source=metadata.source,
    )


@router.delete("")
async def delete_license(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, bool]:
    """
    Delete the current license.
    Admin only - removes license and invalidates cache.
    """
    # Invalidate cache first - if DB delete fails, stale cache is worse than no cache
    try:
        invalidate_license_cache()
    except Exception as cache_error:
        logger.warning(f"Failed to invalidate license cache: {cache_error}")

    deleted = db_delete_license(db_session)

    return {"deleted": deleted}
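For air-gapped installs, the /upload route above accepts the signed license as a multipart upload under the form field license_file. A minimal client-side sketch, assuming the router is mounted under an admin prefix such as /admin/license (the prefix is not shown in this diff) and that an admin bearer token is available:

import requests  # client-side illustration only; not part of this diff

with open("license.key", "rb") as f:
    resp = requests.post(
        "https://onyx.example.com/admin/license/upload",  # hypothetical host + route prefix
        files={"license_file": ("license.key", f)},
        headers={"Authorization": "Bearer <admin token>"},  # auth scheme assumed
        timeout=10,
    )
resp.raise_for_status()
print(resp.json())  # e.g. {"success": true, "message": "License uploaded successfully. ..."}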
backend/ee/onyx/server/license/models.py (new file, 92 lines)
@@ -0,0 +1,92 @@
from datetime import datetime
from enum import Enum

from pydantic import BaseModel

from onyx.server.settings.models import ApplicationStatus


class PlanType(str, Enum):
    MONTHLY = "monthly"
    ANNUAL = "annual"


class LicenseSource(str, Enum):
    AUTO_FETCH = "auto_fetch"
    MANUAL_UPLOAD = "manual_upload"


class LicensePayload(BaseModel):
    """The payload portion of a signed license."""

    version: str
    tenant_id: str
    organization_name: str | None = None
    issued_at: datetime
    expires_at: datetime
    seats: int
    plan_type: PlanType
    billing_cycle: str | None = None
    grace_period_days: int = 30
    stripe_subscription_id: str | None = None
    stripe_customer_id: str | None = None


class LicenseData(BaseModel):
    """Full signed license structure."""

    payload: LicensePayload
    signature: str


class LicenseMetadata(BaseModel):
    """Cached license metadata stored in Redis."""

    tenant_id: str
    organization_name: str | None = None
    seats: int
    used_seats: int
    plan_type: PlanType
    issued_at: datetime
    expires_at: datetime
    grace_period_end: datetime | None = None
    status: ApplicationStatus
    source: LicenseSource | None = None
    stripe_subscription_id: str | None = None


class LicenseStatusResponse(BaseModel):
    """Response for license status API."""

    has_license: bool
    seats: int = 0
    used_seats: int = 0
    plan_type: PlanType | None = None
    issued_at: datetime | None = None
    expires_at: datetime | None = None
    grace_period_end: datetime | None = None
    status: ApplicationStatus | None = None
    source: LicenseSource | None = None


class LicenseResponse(BaseModel):
    """Response after license fetch/upload."""

    success: bool
    message: str | None = None
    license: LicensePayload | None = None


class LicenseUploadResponse(BaseModel):
    """Response after license upload."""

    success: bool
    message: str | None = None


class SeatUsageResponse(BaseModel):
    """Response for seat usage API."""

    total_seats: int
    used_seats: int
    available_seats: int
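These models also pin down the wire format that the fetch and upload endpoints consume: a base64-encoded JSON object with payload and signature keys (see verify_license_signature in backend/ee/onyx/utils/license.py below). A small sketch of producing that envelope, with placeholder values that are purely illustrative:

import base64
import json
from datetime import datetime, timedelta, timezone

# Placeholder field values for illustration only.
payload = LicensePayload(
    version="1",
    tenant_id="tenant_abc",
    issued_at=datetime.now(timezone.utc),
    expires_at=datetime.now(timezone.utc) + timedelta(days=365),
    seats=50,
    plan_type=PlanType.ANNUAL,
)
license_obj = LicenseData(payload=payload, signature="<base64 RSA-PSS signature>")

# Base64-encoded JSON envelope - the exact shape verify_license_signature decodes.
wire_blob = base64.b64encode(
    json.dumps(license_obj.model_dump(mode="json")).encode()
).decode()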
@@ -20,7 +20,7 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -100,7 +100,6 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
use_agentic_search=chat_message_req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@@ -158,7 +157,7 @@ def handle_send_message_simple_with_history(
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -205,7 +204,6 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
use_agentic_search=req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -54,9 +54,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
@@ -76,8 +73,6 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
|
||||
@@ -45,7 +45,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRe
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
@@ -269,7 +269,6 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=name,
|
||||
@@ -296,7 +295,6 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
@@ -562,17 +560,11 @@ async def assign_tenant_to_user(
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=email,
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
@@ -249,6 +249,17 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
)
|
||||
raise
|
||||
|
||||
# Remove from invited users list since they've accepted
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
invited_users = get_invited_users()
|
||||
if email in invited_users:
|
||||
invited_users.remove(email)
|
||||
write_invited_users(invited_users)
|
||||
logger.info(f"Removed {email} from invited users list after acceptance")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
|
||||
backend/ee/onyx/utils/license.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""RSA-4096 license signature verification utilities."""

import base64
import json
import os
from datetime import datetime
from datetime import timezone

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

from ee.onyx.server.license.models import LicenseData
from ee.onyx.server.license.models import LicensePayload
from onyx.server.settings.models import ApplicationStatus
from onyx.utils.logger import setup_logger

logger = setup_logger()


# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")


def _get_public_key() -> RSAPublicKey:
    """Load the public key from environment variable."""
    if not LICENSE_PUBLIC_KEY_PEM:
        raise ValueError(
            "LICENSE_PUBLIC_KEY_PEM environment variable not set. "
            "License verification requires the control plane public key."
        )
    key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
    if not isinstance(key, RSAPublicKey):
        raise ValueError("Expected RSA public key")
    return key


def verify_license_signature(license_data: str) -> LicensePayload:
    """
    Verify RSA-4096 signature and return payload if valid.

    Args:
        license_data: Base64-encoded JSON containing payload and signature

    Returns:
        LicensePayload if signature is valid

    Raises:
        ValueError: If license data is invalid or signature verification fails
    """
    try:
        # Decode the license data
        decoded = json.loads(base64.b64decode(license_data))
        license_obj = LicenseData(**decoded)

        payload_json = json.dumps(
            license_obj.payload.model_dump(mode="json"), sort_keys=True
        )
        signature_bytes = base64.b64decode(license_obj.signature)

        # Verify signature using PSS padding (modern standard)
        public_key = _get_public_key()
        public_key.verify(
            signature_bytes,
            payload_json.encode(),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH,
            ),
            hashes.SHA256(),
        )

        return license_obj.payload

    except InvalidSignature:
        logger.error("License signature verification failed")
        raise ValueError("Invalid license signature")
    except json.JSONDecodeError:
        logger.error("Failed to decode license JSON")
        raise ValueError("Invalid license format: not valid JSON")
    except (ValueError, KeyError, TypeError) as e:
        logger.error(f"License data validation error: {type(e).__name__}")
        raise ValueError(f"Invalid license format: {type(e).__name__}")
    except Exception:
        logger.exception("Unexpected error during license verification")
        raise ValueError("License verification failed: unexpected error")


def get_license_status(
    payload: LicensePayload,
    grace_period_end: datetime | None = None,
) -> ApplicationStatus:
    """
    Determine current license status based on expiry.

    Args:
        payload: The verified license payload
        grace_period_end: Optional grace period end datetime

    Returns:
        ApplicationStatus indicating current license state
    """
    now = datetime.now(timezone.utc)

    # Check if grace period has expired
    if grace_period_end and now > grace_period_end:
        return ApplicationStatus.GATED_ACCESS

    # Check if license has expired
    if now > payload.expires_at:
        if grace_period_end and now <= grace_period_end:
            return ApplicationStatus.GRACE_PERIOD
        return ApplicationStatus.GATED_ACCESS

    # License is valid
    return ApplicationStatus.ACTIVE


def is_license_valid(payload: LicensePayload) -> bool:
    """Check if a license is currently valid (not expired)."""
    now = datetime.now(timezone.utc)
    return now <= payload.expires_at
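verify_license_signature fixes the signing convention: the signature must cover the canonical JSON of the payload (pydantic json mode, sort_keys=True) under RSA-PSS with SHA-256 and a max-length salt. The control-plane signing code is not part of this diff, but under those assumptions the counterpart would look roughly like:

import base64
import json

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey

from ee.onyx.server.license.models import LicensePayload


def sign_license_payload(private_key: RSAPrivateKey, payload: LicensePayload) -> str:
    """Sketch of the control-plane side; mirrors verify_license_signature above."""
    # Same canonical form that the verifier rebuilds before checking the signature.
    payload_json = json.dumps(payload.model_dump(mode="json"), sort_keys=True)
    signature = private_key.sign(
        payload_json.encode(),
        padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        ),
        hashes.SHA256(),
    )
    return base64.b64encode(signature).decode()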
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
@@ -36,8 +36,8 @@ from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import-untyped]
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -42,7 +42,7 @@ def get_embedding_model(
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
@@ -91,7 +91,7 @@ def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
@@ -195,7 +195,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
# Run CPU-bound reranking in a thread pool
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
from transformers import logging as transformer_logging
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
@@ -34,7 +34,7 @@ class HybridClassifier(nn.Module):
|
||||
query_ids: torch.Tensor,
|
||||
query_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
# Intent classification on the CLS token
|
||||
@@ -102,7 +102,7 @@ class ConnectorClassifier(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.distilbert( # type: ignore
|
||||
hidden_states = self.distilbert(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_access_for_document(
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
return versioned_get_access_for_document_fn(document_id, db_session)
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
@@ -93,9 +93,7 @@ def get_access_for_documents(
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
) # type: ignore
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
@@ -113,7 +111,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_acl_for_user"
|
||||
)
|
||||
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||
return versioned_acl_for_user_fn(user, db_session)
|
||||
|
||||
|
||||
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
|
||||
|
||||
@@ -117,7 +117,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -338,9 +338,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
user_created = True
|
||||
except IntegrityError as error:
|
||||
# Race condition: another request created the same user after the
|
||||
@@ -604,10 +602,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
remove_user_from_invited_users(user.email)
|
||||
@@ -653,19 +648,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
event_type = (
|
||||
MilestoneRecordType.USER_SIGNED_UP
|
||||
if user_count == 1
|
||||
else MilestoneRecordType.MULTIPLE_USERS
|
||||
)
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=event_type,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
event=MilestoneRecordType.USER_SIGNED_UP,
|
||||
)
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
@@ -1186,7 +1173,7 @@ async def _sync_jwt_oidc_expiry(
|
||||
return
|
||||
|
||||
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
|
||||
user.oidc_expiry = oidc_expiry # type: ignore
|
||||
user.oidc_expiry = oidc_expiry
|
||||
return
|
||||
|
||||
if user.oidc_expiry is not None:
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
has_successful_attempt = cc_pair.last_successful_index_time is not None
|
||||
priority = (
|
||||
OnyxCeleryPriority.MEDIUM
|
||||
if has_successful_attempt
|
||||
else OnyxCeleryPriority.HIGH
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=priority,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
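IndexingCoordination.try_create_index_attempt carries the actual coordination described in the docstring above, but its body is not included in this diff. As a rough sketch only, the database-side guard it implies is an insert performed in the same transaction as a row-locked check for an already-active attempt; the field and enum names below are assumptions modeled on the surrounding code, not the real implementation:

# Hypothetical sketch - the real logic lives in onyx.db.indexing_coordination.
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.enums import IndexingStatus
from onyx.db.models import IndexAttempt


def try_create_index_attempt_sketch(
    db_session: Session,
    cc_pair_id: int,
    search_settings_id: int,
    celery_task_id: str,
    from_beginning: bool,
) -> int | None:
    # Row-lock any active attempt so two callers cannot both see "nothing running".
    active = (
        db_session.execute(
            select(IndexAttempt)
            .where(
                IndexAttempt.connector_credential_pair_id == cc_pair_id,
                IndexAttempt.search_settings_id == search_settings_id,
                IndexAttempt.status.in_(
                    [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
                ),
            )
            .with_for_update()
        )
        .scalars()
        .first()
    )
    if active is not None:
        return None  # another attempt already owns this cc_pair / search settings

    attempt = IndexAttempt(
        connector_credential_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        celery_task_id=celery_task_id,
        from_beginning=from_beginning,
        status=IndexingStatus.NOT_STARTED,
    )
    db_session.add(attempt)
    db_session.commit()
    return attempt.id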
@@ -25,14 +25,14 @@ from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state
|
||||
from onyx.background.celery.tasks.docprocessing.utils import should_index
|
||||
from onyx.background.celery.tasks.docprocessing.utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.models import DocProcessingContext
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import (
|
||||
@@ -45,6 +45,7 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -108,6 +109,7 @@ from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -547,6 +549,12 @@ def check_indexing_completion(
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=tenant_id,
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
@@ -1404,8 +1412,13 @@ def _docprocessing_task(
|
||||
)
|
||||
|
||||
# Process documents through indexing pipeline
|
||||
connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
task_logger.info(
|
||||
f"Processing {len(documents)} documents through indexing pipeline"
|
||||
f"Processing {len(documents)} documents through indexing pipeline: "
|
||||
f"cc_pair_id={cc_pair_id}, source={connector_source}, "
|
||||
f"batch_num={batch_num}"
|
||||
)
|
||||
|
||||
adapter = DocumentIndexingBatchAdapter(
|
||||
@@ -1495,6 +1508,8 @@ def _docprocessing_task(
|
||||
|
||||
# FIX: Explicitly clear document batch from memory and force garbage collection
|
||||
# This helps prevent memory accumulation across multiple batches
|
||||
# NOTE: Thread-local event loops in embedding threads are cleaned up automatically
|
||||
# via the _cleanup_thread_local decorator in search_nlp_models.py
|
||||
del documents
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -24,8 +17,6 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -298,112 +289,3 @@ def should_index(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
@@ -127,12 +127,6 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
f"available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.default_model_name = available_models[0]
|
||||
if default_provider.fast_default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Fast default model {default_provider.fast_default_model_name} "
|
||||
f"not available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.fast_default_model_name = available_models[0]
|
||||
db_session.commit()
|
||||
|
||||
if added_models or removed_models:
|
||||
|
||||
@@ -55,8 +55,8 @@ class RetryDocumentIndex:
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
) -> None:
|
||||
self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
|
||||
@@ -95,7 +95,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
@@ -114,7 +113,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
chunks_affected = retry_index.delete_single(
|
||||
_ = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
@@ -157,7 +156,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
@@ -187,7 +186,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
|
||||
@@ -597,7 +597,7 @@ def process_single_user_file_project_sync(
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
@@ -606,7 +606,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -874,7 +874,10 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
updated_chunks = retry_index.update_single(
|
||||
# WARNING: In the future this will error; we no longer want
|
||||
# to support changing document ID.
|
||||
# TODO(andrei): Delete soon.
|
||||
retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
@@ -883,7 +886,7 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = updated_chunks
|
||||
user_file.chunk_count = chunk_count
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
|
||||
@@ -501,7 +501,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
@@ -515,10 +515,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"doc={document_id} " f"action=sync " f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except SoftTimeLimitExceeded:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -21,7 +20,6 @@ from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -32,11 +30,8 @@ from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
@@ -49,34 +44,16 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -272,583 +249,6 @@ def _check_failure_threshold(
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >2 months, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
    start_time = time.monotonic()  # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
)
|
||||
|
||||
if index_attempt_start.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
is_primary = (
|
||||
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
from_beginning = index_attempt_start.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=is_primary,
|
||||
should_fetch_permissions_during_indexing=(
|
||||
index_attempt_start.connector_credential_pair.access_type
|
||||
== AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt_start.search_settings_id,
|
||||
db_session=db_session_temp,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# add start/end now that they have been set
|
||||
index_attempt_start.poll_range_start = window_start
|
||||
index_attempt_start.poll_range_end = window_end
|
||||
db_session_temp.add(index_attempt_start)
|
||||
db_session_temp.commit()
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
attempt_id=index_attempt_id,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
total_failures = 0
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=ctx.should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint, _ = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
# save the initial checkpoint to have a proper record of the
|
||||
# "last used checkpoint"
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
unresolved_only=True,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
for error in unresolved_errors:
|
||||
if error.document_id:
|
||||
doc_id_to_unresolved_errors[error.document_id].append(error)
|
||||
|
||||
entity_based_unresolved_errors = [
|
||||
error for error in unresolved_errors if error.entity_id
|
||||
]
|
||||
|
        while checkpoint.has_more:
            logger.info(
                f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
            )
            for document_batch, failure, next_checkpoint in connector_runner.run(
                checkpoint
            ):
                # Check if connector is disabled mid run and stop if so unless it's the secondary
                # index being built. We want to populate it even for paused connectors
                # Often paused connectors are sources that aren't updated frequently but the
                # contents still need to be initially pulled.
                if callback:
                    if callback.should_stop():
                        raise ConnectorStopSignal("Connector stop signal detected")

                    # NOTE: this progress callback runs on every loop. We've seen cases
                    # where we loop many times with no new documents and eventually time
                    # out, so only doing the callback after indexing isn't sufficient.
                    callback.progress("_run_indexing", 0)

                # TODO: should we move this into the above callback instead?
                with get_session_with_current_tenant() as db_session_temp:
                    # will exception if the connector/index attempt is marked as paused/failed
                    _check_connector_and_attempt_status(
                        db_session_temp,
                        ctx.cc_pair_id,
                        ctx.search_settings_status,
                        index_attempt_id,
                    )

                # save record of any failures at the connector level
                if failure is not None:
                    total_failures += 1
                    with get_session_with_current_tenant() as db_session_temp:
                        create_index_attempt_error(
                            index_attempt_id,
                            ctx.cc_pair_id,
                            failure,
                            db_session_temp,
                        )

                    _check_failure_threshold(
                        total_failures, document_count, batch_num, failure
                    )

                # save the new checkpoint (if one is provided)
                if next_checkpoint:
                    checkpoint = next_checkpoint

                # below is all document processing logic, so if no batch we can just continue
                if document_batch is None:
                    continue

                batch_description = []

                # Generate an ID that can be used to correlate activity between here
                # and the embedding model server
                doc_batch_cleaned = strip_null_characters(document_batch)
                for doc in doc_batch_cleaned:
                    batch_description.append(doc.to_short_descriptor())

                    doc_size = 0
                    for section in doc.sections:
                        if (
                            isinstance(section, TextSection)
                            and section.text is not None
                        ):
                            doc_size += len(section.text)

                    if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
                        logger.warning(
                            f"Document size: doc='{doc.to_short_descriptor()}' "
                            f"size={doc_size} "
                            f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
                        )

                logger.debug(f"Indexing batch of documents: {batch_description}")

                index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
                index_attempt_md.structured_id = (
                    f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
                )
                index_attempt_md.batch_num = batch_num + 1  # use 1-index for this

                # real work happens here!
                adapter = DocumentIndexingBatchAdapter(
                    db_session=db_session,
                    connector_id=ctx.connector_id,
                    credential_id=ctx.credential_id,
                    tenant_id=tenant_id,
                    index_attempt_metadata=index_attempt_md,
                )
                index_pipeline_result = run_indexing_pipeline(
                    embedder=embedding_model,
                    information_content_classification_model=information_content_classification_model,
                    document_index=document_index,
                    ignore_time_skip=(
                        ctx.from_beginning
                        or (ctx.search_settings_status == IndexModelStatus.FUTURE)
                    ),
                    db_session=db_session,
                    tenant_id=tenant_id,
                    document_batch=doc_batch_cleaned,
                    request_id=index_attempt_md.request_id,
                    adapter=adapter,
                )

                batch_num += 1
                net_doc_change += index_pipeline_result.new_docs
                chunk_count += index_pipeline_result.total_chunks
                document_count += index_pipeline_result.total_docs

                # resolve errors for documents that were successfully indexed
                failed_document_ids = [
                    failure.failed_document.document_id
                    for failure in index_pipeline_result.failures
                    if failure.failed_document
                ]
                successful_document_ids = [
                    document.id
                    for document in document_batch
                    if document.id not in failed_document_ids
                ]
                for document_id in successful_document_ids:
                    with get_session_with_current_tenant() as db_session_temp:
                        if document_id in doc_id_to_unresolved_errors:
                            logger.info(
                                f"Resolving IndexAttemptError for document '{document_id}'"
                            )
                            for error in doc_id_to_unresolved_errors[document_id]:
                                error.is_resolved = True
                                db_session_temp.add(error)
                                db_session_temp.commit()

                # add brand new failures
                if index_pipeline_result.failures:
                    total_failures += len(index_pipeline_result.failures)
                    with get_session_with_current_tenant() as db_session_temp:
                        for failure in index_pipeline_result.failures:
                            create_index_attempt_error(
                                index_attempt_id,
                                ctx.cc_pair_id,
                                failure,
                                db_session_temp,
                            )

                    _check_failure_threshold(
                        total_failures,
                        document_count,
                        batch_num,
                        index_pipeline_result.failures[-1],
                    )

                # This new value is updated every batch, so UI can refresh per batch update
                with get_session_with_current_tenant() as db_session_temp:
                    # NOTE: Postgres uses the start of the transactions when computing `NOW()`
                    # so we need either to commit() or to use a new session
                    update_docs_indexed(
                        db_session=db_session_temp,
                        index_attempt_id=index_attempt_id,
                        total_docs_indexed=document_count,
                        new_docs_indexed=net_doc_change,
                        docs_removed_from_index=0,
                    )

                if callback:
                    callback.progress("_run_indexing", len(doc_batch_cleaned))

                # Add telemetry for indexing progress
                optional_telemetry(
                    record_type=RecordType.INDEXING_PROGRESS,
                    data={
                        "index_attempt_id": index_attempt_id,
                        "cc_pair_id": ctx.cc_pair_id,
                        "current_docs_indexed": document_count,
                        "current_chunks_indexed": chunk_count,
                        "source": ctx.source.value,
                    },
                    tenant_id=tenant_id,
                )

                memory_tracer.increment_and_maybe_trace()
            # make sure the checkpoints aren't getting too large at some regular interval
            CHECKPOINT_SIZE_CHECK_INTERVAL = 100
            if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
                check_checkpoint_size(checkpoint)

            # save latest checkpoint
            with get_session_with_current_tenant() as db_session_temp:
                save_checkpoint(
                    db_session=db_session_temp,
                    index_attempt_id=index_attempt_id,
                    checkpoint=checkpoint,
                )

        optional_telemetry(
            record_type=RecordType.INDEXING_COMPLETE,
            data={
                "index_attempt_id": index_attempt_id,
                "cc_pair_id": ctx.cc_pair_id,
                "total_docs_indexed": document_count,
                "total_chunks": chunk_count,
                "time_elapsed_seconds": time.monotonic() - start_time,
                "source": ctx.source.value,
            },
            tenant_id=tenant_id,
        )
    except Exception as e:
        logger.exception(
            "Connector run exceptioned after elapsed time: "
            f"{time.monotonic() - start_time} seconds"
        )
        if isinstance(e, ConnectorValidationError):
            # On validation errors during indexing, we want to cancel the indexing attempt
            # and mark the CCPair as invalid. This prevents the connector from being
            # used in the future until the credentials are updated.
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to validation error."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
                )

                if ctx.is_primary:
                    if not index_attempt:
                        # should always be set by now
                        raise RuntimeError("Should never happen.")

                    VALIDATION_ERROR_THRESHOLD = 5

                    recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
                        cc_pair_id=ctx.cc_pair_id,
                        search_settings_id=index_attempt.search_settings_id,
                        limit=VALIDATION_ERROR_THRESHOLD,
                        db_session=db_session_temp,
                    )
                    num_validation_errors = len(
                        [
                            index_attempt
                            for index_attempt in recent_index_attempts
                            if index_attempt.error_msg
                            and index_attempt.error_msg.startswith(
                                CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
                            )
                        ]
                    )

                    if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
                        logger.warning(
                            f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
                            f" errors. Marking the CC Pair as invalid."
                        )
                        update_connector_credential_pair(
                            db_session=db_session_temp,
                            connector_id=ctx.connector_id,
                            credential_id=ctx.credential_id,
                            status=ConnectorCredentialPairStatus.INVALID,
                        )
            memory_tracer.stop()
            raise e

        elif isinstance(e, ConnectorStopSignal):
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to stop signal."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=str(e),
                )

                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        net_docs=net_doc_change,
                    )

            memory_tracer.stop()
            raise e
        else:
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session_temp,
                    failure_reason=str(e),
                    full_exception_trace=traceback.format_exc(),
                )

                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        net_docs=net_doc_change,
                    )

            memory_tracer.stop()
            raise e

    memory_tracer.stop()
    # we know index attempt is successful (at least partially) at this point,
    # all other cases have been short-circuited
    elapsed_time = time.monotonic() - start_time
    with get_session_with_current_tenant() as db_session_temp:
        # resolve entity-based errors
        for error in entity_based_unresolved_errors:
            logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
            error.is_resolved = True
            db_session_temp.add(error)
            db_session_temp.commit()

        if total_failures == 0:
            mark_attempt_succeeded(index_attempt_id, db_session_temp)

            create_milestone_and_report(
                user=None,
                distinct_id=tenant_id or "N/A",
                event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
                properties=None,
                db_session=db_session_temp,
            )

            logger.info(
                f"Connector succeeded: "
                f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
            )

        else:
            mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
            logger.info(
                f"Connector completed with some errors: "
                f"failures={total_failures} "
                f"batches={batch_num} "
                f"docs={document_count} "
                f"chunks={chunk_count} "
                f"elapsed={elapsed_time:.2f}s"
            )

        if ctx.is_primary:
            update_connector_credential_pair(
                db_session=db_session_temp,
                connector_id=ctx.connector_id,
                credential_id=ctx.credential_id,
                run_dt=window_end,
            )
            if ctx.should_fetch_permissions_during_indexing:
                mark_cc_pair_as_permissions_synced(
                    db_session=db_session_temp,
                    cc_pair_id=ctx.cc_pair_id,
                    start_time=window_end,
                )

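The body above follows a checkpoint-resume pattern: start from the last valid checkpoint (or a dummy one when the previous attempt succeeded), drain batches while has_more is set, and persist the checkpoint as it advances. Below is a minimal standalone sketch of that shape; SimpleCheckpoint and run_batches are hypothetical stand-ins, not the real connector runner or DB helpers shown in the diff.

from dataclasses import dataclass

@dataclass
class SimpleCheckpoint:
    has_more: bool = True
    cursor: int = 0

def run_batches(checkpoint: SimpleCheckpoint):
    # Mirrors the (batch, failure, next_checkpoint) tuples yielded by connector_runner.run()
    for i in range(checkpoint.cursor, 3):
        yield [f"doc-{i}"], None, SimpleCheckpoint(has_more=(i < 2), cursor=i + 1)

checkpoint = SimpleCheckpoint()
while checkpoint.has_more:
    for batch, failure, next_checkpoint in run_batches(checkpoint):
        if failure is not None:
            pass  # the real code records an IndexAttemptError and checks the failure threshold
        if next_checkpoint:
            checkpoint = next_checkpoint
        if batch is None:
            continue
        # the real code runs the indexing pipeline on the batch here
    print(f"persisting checkpoint at cursor={checkpoint.cursor}")  # save_checkpoint in the real code
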
def run_docfetching_entrypoint(
    app: Celery,
    index_attempt_id: int,
@@ -968,11 +368,19 @@ def connector_document_extraction(
    db_connector = index_attempt.connector_credential_pair.connector
    db_credential = index_attempt.connector_credential_pair.credential
    is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT

    from_beginning = index_attempt.from_beginning
    has_successful_attempt = (
        index_attempt.connector_credential_pair.last_successful_index_time
        is not None
    )
    # Use higher priority for first-time indexing to ensure new connectors
    # get processed before re-indexing of existing connectors
    docprocessing_priority = (
        OnyxCeleryPriority.MEDIUM
        if has_successful_attempt
        else OnyxCeleryPriority.HIGH
    )

    earliest_index_time = (
        db_connector.indexing_start.timestamp()
@@ -1095,6 +503,7 @@ def connector_document_extraction(
        tenant_id,
        app,
        most_recent_attempt,
        docprocessing_priority,
    )
    last_batch_num = reissued_batch_count + completed_batches
    index_attempt.completed_batches = completed_batches
@@ -1207,7 +616,7 @@ def connector_document_extraction(
            OnyxCeleryTask.DOCPROCESSING_TASK,
            kwargs=processing_batch_data,
            queue=OnyxCeleryQueues.DOCPROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
            priority=docprocessing_priority,
        )

        batch_num += 1
@@ -1358,6 +767,7 @@ def reissue_old_batches(
    tenant_id: str,
    app: Celery,
    most_recent_attempt: IndexAttempt | None,
    priority: OnyxCeleryPriority,
) -> tuple[int, int]:
    # When loading from a checkpoint, we need to start new docprocessing tasks
    # tied to the new index attempt for any batches left over in the file store
@@ -1385,7 +795,7 @@ def reissue_old_batches(
                "batch_num": path_info.batch_num,  # use same batch num as previously
            },
            queue=OnyxCeleryQueues.DOCPROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
            priority=priority,
        )
    recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
    # resume from the batch num of the last attempt. This should be one more

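The hunks above thread a docprocessing priority through so first-time indexing outranks re-indexing work. A hedged sketch of the selection and enqueue follows; app.send_task, the task/queue names, and the numeric priority values are illustrative assumptions (the diff only shows the kwargs/queue/priority arguments and the OnyxCeleryPriority members).

from celery import Celery

app = Celery("sketch")
HIGH, MEDIUM = 0, 5  # numeric stand-ins for OnyxCeleryPriority.HIGH / MEDIUM

def enqueue_docprocessing(has_successful_attempt: bool, processing_batch_data: dict) -> None:
    # first-time indexing (no successful attempt yet) gets the higher priority
    docprocessing_priority = MEDIUM if has_successful_attempt else HIGH
    app.send_task(
        "docprocessing_task",      # OnyxCeleryTask.DOCPROCESSING_TASK in the real code
        kwargs=processing_batch_data,
        queue="docprocessing",     # OnyxCeleryQueues.DOCPROCESSING in the real code
        priority=docprocessing_priority,
    )
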
@@ -1,64 +0,0 @@
"""
Module for handling chat-related milestone tracking and telemetry.
"""

from sqlalchemy.orm import Session

from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import User
from onyx.utils.telemetry import mt_cloud_telemetry


def process_multi_assistant_milestone(
    user: User | None,
    assistant_id: int,
    tenant_id: str,
    db_session: Session,
) -> None:
    """
    Process the multi-assistant milestone for a user.

    This function:
    1. Creates or retrieves the multi-assistant milestone
    2. Updates the milestone with the current assistant usage
    3. Checks if the milestone was just achieved
    4. Sends telemetry if the milestone was just hit

    Args:
        user: The user for whom to process the milestone (can be None for anonymous users)
        assistant_id: The ID of the assistant being used
        tenant_id: The current tenant ID
        db_session: Database session for queries
    """
    # Create or retrieve the multi-assistant milestone
    multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
        user=user,
        event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
        db_session=db_session,
    )

    # Update the milestone with the current assistant usage
    update_user_assistant_milestone(
        milestone=multi_assistant_milestone,
        user_id=str(user.id) if user else NO_AUTH_USER_ID,
        assistant_id=assistant_id,
        db_session=db_session,
    )

    # Check if the milestone was just achieved
    _, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
        milestone=multi_assistant_milestone,
        db_session=db_session,
    )

    # Send telemetry if the milestone was just hit
    if just_hit_multi_assistant_milestone:
        mt_cloud_telemetry(
            distinct_id=tenant_id,
            event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
            properties=None,
        )
@@ -1,10 +1,12 @@
import threading
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any

from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
@@ -18,39 +20,77 @@ class ChatStateContainer:

    This container holds the partial state that can be saved to the database
    if the generation is stopped by the user or completes normally.

    Thread-safe: All write operations are protected by a lock to ensure safe
    concurrent access from multiple threads. For thread-safe reads, use the
    getter methods. Direct attribute access is not thread-safe.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # These are collected at the end after the entire tool call is completed
        self.tool_calls: list[ToolCallInfo] = []
        # This is accumulated during the streaming
        self.reasoning_tokens: str | None = None
        # This is accumulated during the streaming of the answer
        self.answer_tokens: str | None = None
        # Store citation mapping for building citation_docs_info during partial saves
        self.citation_to_doc: dict[int, SearchDoc] = {}
        self.citation_to_doc: CitationMapping = {}
        # True if this turn is a clarification question (deep research flow)
        self.is_clarification: bool = False

    def add_tool_call(self, tool_call: ToolCallInfo) -> None:
        """Add a tool call to the accumulated state."""
        self.tool_calls.append(tool_call)
        with self._lock:
            self.tool_calls.append(tool_call)

    def set_reasoning_tokens(self, reasoning: str | None) -> None:
        """Set the reasoning tokens from the final answer generation."""
        self.reasoning_tokens = reasoning
        with self._lock:
            self.reasoning_tokens = reasoning

    def set_answer_tokens(self, answer: str | None) -> None:
        """Set the answer tokens from the final answer generation."""
        self.answer_tokens = answer
        with self._lock:
            self.answer_tokens = answer

    def set_citation_mapping(self, citation_to_doc: dict[int, Any]) -> None:
    def set_citation_mapping(self, citation_to_doc: CitationMapping) -> None:
        """Set the citation mapping from citation processor."""
        self.citation_to_doc = citation_to_doc
        with self._lock:
            self.citation_to_doc = citation_to_doc

    def set_is_clarification(self, is_clarification: bool) -> None:
        """Set whether this turn is a clarification question."""
        self.is_clarification = is_clarification
        with self._lock:
            self.is_clarification = is_clarification

    def get_answer_tokens(self) -> str | None:
        """Thread-safe getter for answer_tokens."""
        with self._lock:
            return self.answer_tokens

    def get_reasoning_tokens(self) -> str | None:
        """Thread-safe getter for reasoning_tokens."""
        with self._lock:
            return self.reasoning_tokens

    def get_tool_calls(self) -> list[ToolCallInfo]:
        """Thread-safe getter for tool_calls (returns a copy)."""
        with self._lock:
            return self.tool_calls.copy()

    def get_citation_to_doc(self) -> CitationMapping:
        """Thread-safe getter for citation_to_doc (returns a copy)."""
        with self._lock:
            return self.citation_to_doc.copy()

    def get_is_clarification(self) -> bool:
        """Thread-safe getter for is_clarification."""
        with self._lock:
            return self.is_clarification


def run_chat_llm_with_state_containers(
def run_chat_loop_with_state_containers(
    func: Callable[..., None],
    is_connected: Callable[[], bool],
    emitter: Emitter,
@@ -74,7 +114,7 @@ def run_chat_llm_with_state_containers(
        **kwargs: Additional keyword arguments for func

    Usage:
        packets = run_chat_llm_with_state_containers(
        packets = run_chat_loop_with_state_containers(
            my_func,
            emitter=emitter,
            state_container=state_container,
@@ -95,7 +135,7 @@ def run_chat_llm_with_state_containers(
        # If execution fails, emit an exception packet
        emitter.emit(
            Packet(
                turn_index=0,
                placement=Placement(turn_index=0),
                obj=PacketException(type="error", exception=e),
            )
        )

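The getters and lock-guarded setters above are the whole thread-safety contract of ChatStateContainer. A small hedged usage sketch is below; the import path is an assumption based on the file being diffed.

import threading
from onyx.chat.chat_state import ChatStateContainer  # assumed module path

state = ChatStateContainer()

def generation_thread() -> None:
    # a worker streaming the answer writes through the lock-protected setters
    for partial in ("Hello", "Hello, world"):
        state.set_answer_tokens(partial)
    state.set_is_clarification(False)

worker = threading.Thread(target=generation_thread)
worker.start()
worker.join()

# reads go through the getters, never direct attribute access
print(state.get_answer_tokens())     # "Hello, world"
print(state.get_is_clarification())  # False
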
@@ -38,6 +38,7 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.search_settings import get_current_search_settings
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
@@ -49,8 +50,10 @@ from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
@@ -71,7 +74,6 @@ def prepare_chat_message_request(
    retrieval_details: RetrievalDetails | None,
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    use_agentic_search: bool = False,
    skip_gen_ai_answer_generation: bool = False,
    llm_override: LLMOverride | None = None,
    allowed_tool_ids: list[int] | None = None,
@@ -98,7 +100,6 @@ def prepare_chat_message_request(
        search_doc_ids=None,
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        use_agentic_search=use_agentic_search,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
        llm_override=llm_override,
        allowed_tool_ids=allowed_tool_ids,
@@ -483,10 +484,14 @@ def load_chat_file(

    if file_type.is_text_file():
        try:
            content_text = content.decode("utf-8")
        except UnicodeDecodeError:
            content_text = extract_file_text(
                file=file_io,
                file_name=file_descriptor.get("name") or "",
                break_on_unprocessable=False,
            )
        except Exception as e:
            logger.warning(
                f"Failed to decode text content for file {file_descriptor['id']}"
                f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
            )

    # Get token count from UserFile if available
@@ -581,9 +586,16 @@ def convert_chat_history(

    # Add text files as separate messages before the user message
    for text_file in text_files:
        file_text = text_file.content_text or ""
        filename = text_file.filename
        message = (
            f"File: {filename}\n{file_text}\nEnd of File"
            if filename
            else file_text
        )
        simple_messages.append(
            ChatMessageSimple(
                message=text_file.content_text or "",
                message=message,
                token_count=text_file.token_count,
                message_type=MessageType.USER,
                image_files=None,
@@ -729,3 +741,38 @@ def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) ->
        if message.message_type == MessageType.ASSISTANT:
            return message.is_clarification
    return False


def create_tool_call_failure_messages(
    tool_call: ToolCallKickoff, token_counter: Callable[[str], int]
) -> list[ChatMessageSimple]:
    """Create ChatMessageSimple objects for a failed tool call.

    Creates two messages:
    1. The tool call message itself
    2. A failure response message indicating the tool call failed

    Args:
        tool_call: The ToolCallKickoff object representing the failed tool call
        token_counter: Function to count tokens in a message string

    Returns:
        List containing two ChatMessageSimple objects: tool call message and failure response
    """
    tool_call_msg = ChatMessageSimple(
        message=tool_call.to_msg_str(),
        token_count=token_counter(tool_call.to_msg_str()),
        message_type=MessageType.TOOL_CALL,
        tool_call_id=tool_call.tool_call_id,
        image_files=None,
    )

    failure_response_msg = ChatMessageSimple(
        message=TOOL_CALL_FAILURE_PROMPT,
        token_count=token_counter(TOOL_CALL_FAILURE_PROMPT),
        message_type=MessageType.TOOL_CALL_RESPONSE,
        tool_call_id=tool_call.tool_call_id,
        image_files=None,
    )

    return [tool_call_msg, failure_response_msg]

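For the failed-tool-call helper above, the shape it produces can be sketched with hypothetical stand-ins; FakeMessage and the toy token counter below are not the real ChatMessageSimple or tokenizer, they only mirror the two-message structure keyed by tool_call_id.

from dataclasses import dataclass

@dataclass
class FakeMessage:
    message: str
    token_count: int
    message_type: str
    tool_call_id: str

def sketch_failure_messages(tool_call_str: str, tool_call_id: str) -> list[FakeMessage]:
    def count(text: str) -> int:
        return len(text.split())  # toy token counter
    failure_prompt = "The tool call failed; answer without its output."  # stand-in for TOOL_CALL_FAILURE_PROMPT
    return [
        FakeMessage(tool_call_str, count(tool_call_str), "TOOL_CALL", tool_call_id),
        FakeMessage(failure_prompt, count(failure_prompt), "TOOL_CALL_RESPONSE", tool_call_id),
    ]

print(sketch_failure_messages("search(query='onyx docs')", "call_1"))
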
@@ -4,13 +4,15 @@ Dynamic Citation Processor for LLM Responses

This module provides a citation processor that can:
- Accept citation number to SearchDoc mappings dynamically
- Process token streams from LLMs to extract citations
- Remove citation markers from output text
- Emit CitationInfo objects for detected citations
- Optionally replace citation markers with formatted markdown links
- Emit CitationInfo objects for detected citations (when replacing)
- Track all seen citations regardless of replacement mode
- Maintain a list of cited documents in order of first citation
"""

import re
from collections.abc import Generator
from typing import TypeAlias

from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.context.search.models import SearchDoc
@@ -21,8 +23,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


CitationMapping: TypeAlias = dict[int, SearchDoc]


# ============================================================================
# Utility functions (copied for self-containment)
# Utility functions
# ============================================================================


@@ -43,19 +48,29 @@ class DynamicCitationProcessor:

    This processor is designed for multi-turn conversations where the citation
    number to document mapping is provided externally. It processes streaming
    tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and:
    tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and based
    on the `replace_citation_tokens` setting:

    1. Removes citation markers from the output text
    2. Emits CitationInfo objects for tracking
    3. Maintains the order in which documents were first cited
    When replace_citation_tokens=True (default):
    1. Replaces citation markers with formatted markdown links (e.g., [[1]](url))
    2. Emits CitationInfo objects for tracking
    3. Maintains the order in which documents were first cited

    When replace_citation_tokens=False:
    1. Preserves original citation markers in the output text
    2. Does NOT emit CitationInfo objects
    3. Still tracks all seen citations via get_seen_citations()

    Features:
    - Accepts citation number → SearchDoc mapping via update_citation_mapping()
    - Processes tokens from LLM and removes citation markers
    - Holds back tokens that might be partial citations
    - Maintains list of cited SearchDocs in order of first citation
    - Accepts citation number → SearchDoc mapping via update_citation_mapping()
    - Configurable citation replacement behavior at initialization
    - Always tracks seen citations regardless of replacement mode
    - Holds back tokens that might be partial citations
    - Maintains list of cited SearchDocs in order of first citation
    - Handles unicode bracket variants (【】, [])
    - Skips citation processing inside code blocks

    Example:
    Example (with citation replacement - default):
        processor = DynamicCitationProcessor()

        # Set up citation mapping
@@ -65,37 +80,55 @@ class DynamicCitationProcessor:
        for token in llm_stream:
            for result in processor.process_token(token):
                if isinstance(result, str):
                    print(result)  # Display text (citations removed)
                    print(result)  # Display text with [[1]](url) format
                elif isinstance(result, CitationInfo):
                    handle_citation(result)  # Track citation

        # Update mapping with more documents
        processor.update_citation_mapping({3: search_doc3, 4: search_doc4})

        # Continue processing...

        # Get cited documents at the end
        cited_docs = processor.get_cited_documents()

    Example (without citation replacement):
        processor = DynamicCitationProcessor(replace_citation_tokens=False)
        processor.update_citation_mapping({1: search_doc1, 2: search_doc2})

        # Process tokens from LLM
        for token in llm_stream:
            for result in processor.process_token(token):
                # Only strings are yielded, no CitationInfo objects
                print(result)  # Display text with original [1] format preserved

        # Get all seen citations after processing
        seen_citations = processor.get_seen_citations()  # {1: search_doc1, ...}
    """

    def __init__(
        self,
        replace_citation_tokens: bool = True,
        stop_stream: str | None = STOP_STREAM_PAT,
    ):
        """
        Initialize the citation processor.

        Args:
            stop_stream: Optional stop token to halt processing early
            replace_citation_tokens: If True (default), citations like [1] are replaced
                with formatted markdown links like [[1]](url) and CitationInfo objects
                are emitted. If False, original citation text is preserved in output
                and no CitationInfo objects are emitted. Regardless of this setting,
                all seen citations are tracked and available via get_seen_citations().
            stop_stream: Optional stop token pattern to halt processing early.
                When this pattern is detected in the token stream, processing stops.
                Defaults to STOP_STREAM_PAT from chat configs.
        """
        # Citation mapping from citation number to SearchDoc
        self.citation_to_doc: dict[int, SearchDoc] = {}
        self.citation_to_doc: CitationMapping = {}
        self.seen_citations: CitationMapping = {}  # citation num -> SearchDoc

        # Token processing state
        self.llm_out = ""  # entire output so far
        self.curr_segment = ""  # tokens held for citation processing
        self.hold = ""  # tokens held for stop token processing
        self.stop_stream = stop_stream
        self.replace_citation_tokens = replace_citation_tokens

        # Citation tracking
        self.cited_documents_in_order: list[SearchDoc] = (
@@ -119,7 +152,11 @@ class DynamicCitationProcessor:
            r"([\[【[]{2}\d+[\]】]]{2})|([\[【[]\d+(?:, ?\d+)*[\]】]])"
        )

    def update_citation_mapping(self, citation_mapping: dict[int, SearchDoc]) -> None:
    def update_citation_mapping(
        self,
        citation_mapping: CitationMapping,
        update_duplicate_keys: bool = False,
    ) -> None:
        """
        Update the citation number to SearchDoc mapping.

@@ -128,15 +165,25 @@ class DynamicCitationProcessor:

        Args:
            citation_mapping: Dictionary mapping citation numbers (1, 2, 3, ...) to SearchDoc objects
            update_duplicate_keys: If True, update existing mappings with new values when keys overlap.
                If False (default), filter out duplicate keys and only add non-duplicates.
                The default behavior is useful when OpenURL may have the same citation number as a
                Web Search result - in those cases, we keep the web search citation and snippet etc.
        """
        # Filter out duplicate keys and only add non-duplicates
        # Reason for this is that OpenURL may have the same citation number as a Web Search result
        # For those, we should just keep the web search citation and snippet etc.
        duplicate_keys = set(citation_mapping.keys()) & set(self.citation_to_doc.keys())
        non_duplicate_mapping = {
            k: v for k, v in citation_mapping.items() if k not in duplicate_keys
        }
        self.citation_to_doc.update(non_duplicate_mapping)
        if update_duplicate_keys:
            # Update all mappings, including duplicates
            self.citation_to_doc.update(citation_mapping)
        else:
            # Filter out duplicate keys and only add non-duplicates
            # Reason for this is that OpenURL may have the same citation number as a Web Search result
            # For those, we should just keep the web search citation and snippet etc.
            duplicate_keys = set(citation_mapping.keys()) & set(
                self.citation_to_doc.keys()
            )
            non_duplicate_mapping = {
                k: v for k, v in citation_mapping.items() if k not in duplicate_keys
            }
            self.citation_to_doc.update(non_duplicate_mapping)

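The duplicate-key rule above is easiest to see with plain dictionaries. The sketch below uses strings as stand-ins for SearchDoc objects and mirrors only the merge logic, not the real class.

def merge(existing: dict[int, str], incoming: dict[int, str], update_duplicate_keys: bool) -> dict[int, str]:
    merged = dict(existing)
    if update_duplicate_keys:
        # later mapping wins on overlapping citation numbers
        merged.update(incoming)
    else:
        # keep the original entry (e.g., the web search result) on overlap
        merged.update({k: v for k, v in incoming.items() if k not in existing})
    return merged

existing = {1: "web-search-doc"}
incoming = {1: "open-url-doc", 2: "another-doc"}
print(merge(existing, incoming, update_duplicate_keys=False))  # {1: 'web-search-doc', 2: 'another-doc'}
print(merge(existing, incoming, update_duplicate_keys=True))   # {1: 'open-url-doc', 2: 'another-doc'}
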
    def process_token(
        self, token: str | None
@@ -147,17 +194,24 @@ class DynamicCitationProcessor:
        This method:
        1. Accumulates tokens until a complete citation or non-citation is found
        2. Holds back potential partial citations (e.g., "[", "[1")
        3. Yields text chunks when they're safe to display (with citations REMOVED)
        4. Yields CitationInfo when citations are detected
        5. Handles code blocks (avoids processing citations inside code)
        6. Handles stop tokens
        3. Yields text chunks when they're safe to display
        4. Handles code blocks (avoids processing citations inside code)
        5. Handles stop tokens
        6. Always tracks seen citations in self.seen_citations

        Behavior depends on the `replace_citation_tokens` setting from __init__:
        - If True: Citations are replaced with [[n]](url) format and CitationInfo
          objects are yielded before each formatted citation
        - If False: Original citation text (e.g., [1]) is preserved in output
          and no CitationInfo objects are yielded

        Args:
            token: The next token from the LLM stream, or None to signal end of stream
            token: The next token from the LLM stream, or None to signal end of stream.
                Pass None to flush any remaining buffered text at end of stream.

        Yields:
            - str: Text chunks to display (citations removed)
            - CitationInfo: Citation metadata when a citation is detected
            str: Text chunks to display. Citation format depends on replace_citation_tokens.
            CitationInfo: Citation metadata (only when replace_citation_tokens=True)
        """
        # None -> end of stream, flush remaining segment
        if token is None:
@@ -250,17 +304,24 @@ class DynamicCitationProcessor:
                yield intermatch_str

            # Process the citation (returns formatted citation text and CitationInfo objects)
            # Always tracks seen citations regardless of strip_citations flag
            citation_text, citation_info_list = self._process_citation(
                match, has_leading_space
                match, has_leading_space, self.replace_citation_tokens
            )
            # Yield CitationInfo objects BEFORE the citation text
            # This allows the frontend to receive citation metadata before the token
            # that contains [[n]](link), enabling immediate rendering
            for citation in citation_info_list:
                yield citation
            # Then yield the formatted citation text
            if citation_text:
                yield citation_text

            if self.replace_citation_tokens:
                # Yield CitationInfo objects BEFORE the citation text
                # This allows the frontend to receive citation metadata before the token
                # that contains [[n]](link), enabling immediate rendering
                for citation in citation_info_list:
                    yield citation
                # Then yield the formatted citation text
                if citation_text:
                    yield citation_text
            else:
                # When not stripping, yield the original citation text unchanged
                yield match.group()

            self.non_citation_count = 0

        # Leftover text could be part of next citation
@@ -277,27 +338,42 @@ class DynamicCitationProcessor:
            yield result

    def _process_citation(
        self, match: re.Match, has_leading_space: bool
        self, match: re.Match, has_leading_space: bool, replace_tokens: bool = True
    ) -> tuple[str, list[CitationInfo]]:
        """
        Process a single citation match and return formatted citation text and citation info objects.

        The match string can look like '[1]', '[1, 13, 6]', '[[4]]', '【1】', etc.
        This is an internal method called by process_token(). The match string can be
        in various formats: '[1]', '[1, 13, 6]', '[[4]]', '【1】', '[1]', etc.

        This method:
        This method always:
        1. Extracts citation numbers from the match
        2. Looks up the corresponding SearchDoc from the mapping
        3. Skips duplicate citations if they were recently cited
        4. Creates formatted citation text like [n](link) for each citation
        3. Tracks seen citations in self.seen_citations (regardless of replace_tokens)

        When replace_tokens=True (controlled by self.replace_citation_tokens):
        4. Creates formatted citation text as [[n]](url)
        5. Creates CitationInfo objects for new citations
        6. Handles deduplication of recently cited documents

        When replace_tokens=False:
        4. Returns empty string and empty list (caller yields original match text)

        Args:
            match: Regex match object containing the citation
            has_leading_space: Whether the text before the citation has a leading space
            match: Regex match object containing the citation pattern
            has_leading_space: Whether the text immediately before this citation
                ends with whitespace. Used to determine if a leading space should
                be added to the formatted output.
            replace_tokens: If True, return formatted text and CitationInfo objects.
                If False, only track seen citations and return empty results.
                This is passed from self.replace_citation_tokens by the caller.

        Returns:
            Tuple of (formatted_citation_text, list[CitationInfo])
            - formatted_citation_text: Markdown-formatted citation text like [1](link) [2](link)
            - citation_info_list: List of CitationInfo objects
            Tuple of (formatted_citation_text, citation_info_list):
            - formatted_citation_text: Markdown-formatted citation text like
              "[[1]](https://example.com)" or empty string if replace_tokens=False
            - citation_info_list: List of CitationInfo objects for newly cited
              documents, or empty list if replace_tokens=False
        """
        citation_str: str = match.group()  # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】'
        formatted = (
@@ -335,7 +411,14 @@ class DynamicCitationProcessor:
            doc_id = search_doc.document_id
            link = search_doc.link or ""

            # Always format the citation text as [[n]](link)
            # Always track seen citations regardless of replace_tokens setting
            self.seen_citations[num] = search_doc

            # When not replacing citation tokens, skip the rest of the processing
            if not replace_tokens:
                continue

            # Format the citation text as [[n]](link)
            formatted_citation_parts.append(f"[[{num}]]({link})")

            # Skip creating CitationInfo for citations of the same work if cited recently (deduplication)
@@ -367,8 +450,14 @@ class DynamicCitationProcessor:
        """
        Get the list of cited SearchDoc objects in the order they were first cited.

        Note: This list is only populated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will return an empty list.
        Use get_seen_citations() instead if you need to track citations without
        replacing them.

        Returns:
            List of SearchDoc objects
            List of SearchDoc objects in the order they were first cited.
            Empty list if replace_citation_tokens=False.
        """
        return self.cited_documents_in_order

@@ -376,34 +465,89 @@ class DynamicCitationProcessor:
        """
        Get the list of cited document IDs in the order they were first cited.

        Note: This list is only populated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will return an empty list.
        Use get_seen_citations() instead if you need to track citations without
        replacing them.

        Returns:
            List of document IDs (strings)
            List of document IDs (strings) in the order they were first cited.
            Empty list if replace_citation_tokens=False.
        """
        return [doc.document_id for doc in self.cited_documents_in_order]

    def get_seen_citations(self) -> CitationMapping:
        """
        Get all seen citations as a mapping from citation number to SearchDoc.

        This returns all citations that have been encountered during processing,
        regardless of the `replace_citation_tokens` setting. Citations are tracked
        whenever they are parsed, making this useful for cases where you need to
        know which citations appeared in the text without replacing them.

        This is particularly useful when `replace_citation_tokens=False`, as
        get_cited_documents() will be empty in that case, but get_seen_citations()
        will still contain all the citations that were found.

        Returns:
            Dictionary mapping citation numbers (int) to SearchDoc objects.
            The dictionary is keyed by the citation number as it appeared in
            the text (e.g., {1: SearchDoc(...), 3: SearchDoc(...)}).
        """
        return self.seen_citations

    @property
    def num_cited_documents(self) -> int:
        """Get the number of documents that have been cited."""
        """
        Get the number of unique documents that have been cited.

        Note: This count is only updated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will always return 0.
        Use len(get_seen_citations()) instead if you need to count citations
        without replacing them.

        Returns:
            Number of unique documents cited. 0 if replace_citation_tokens=False.
        """
        return len(self.cited_document_ids)

    def reset_recent_citations(self) -> None:
        """
        Reset the recent citations tracker.

        This can be called to allow previously cited documents to be cited again
        without being filtered out by the deduplication logic.
        The processor tracks "recently cited" documents to avoid emitting duplicate
        CitationInfo objects for the same document when it's cited multiple times
        in close succession. This method clears that tracker.

        This is primarily useful when `replace_citation_tokens=True` to allow
        previously cited documents to emit CitationInfo objects again. Has no
        effect when `replace_citation_tokens=False`.

        The recent citation tracker is also automatically cleared when more than
        5 non-citation characters are processed between citations.
        """
        self.recent_cited_documents.clear()

    def get_next_citation_number(self) -> int:
        """
        Get the next available citation number.
        Get the next available citation number for adding new documents to the mapping.

        This method returns the next citation number that should be used for new documents.
        If no citations exist yet, it returns 1. Otherwise, it returns max + 1.
        This method returns the next citation number that should be used when adding
        new documents via update_citation_mapping(). Useful when dynamically adding
        citations during processing (e.g., from tool results like web search).

        If no citations exist yet in the mapping, returns 1.
        Otherwise, returns max(existing_citation_numbers) + 1.

        Returns:
            The next available citation number (1-indexed)
            The next available citation number (1-indexed integer).

        Example:
            # After adding citations 1, 2, 3
            processor.get_next_citation_number()  # Returns 4

            # With non-sequential citations 1, 5, 10
            processor.get_next_citation_number()  # Returns 11
        """
        if not self.citation_to_doc:
            return 1
backend/onyx/chat/citation_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import re

from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.context.search.models import SearchDocsResponse
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.models import ToolResponse


def update_citation_processor_from_tool_response(
    tool_response: ToolResponse,
    citation_processor: DynamicCitationProcessor,
) -> None:
    """Update citation processor if this was a citeable tool with a SearchDocsResponse.

    Checks if the tool call is citeable and if the response contains a SearchDocsResponse,
    then creates a mapping from citation numbers to SearchDoc objects and updates the
    citation processor.

    Args:
        tool_response: The response from the tool execution (must have tool_call set)
        citation_processor: The DynamicCitationProcessor to update
    """
    # Early return if tool_call is not set
    if tool_response.tool_call is None:
        return

    # Update citation processor if this was a search tool
    if tool_response.tool_call.tool_name in CITEABLE_TOOLS_NAMES:
        # Check if the rich_response is a SearchDocsResponse
        if isinstance(tool_response.rich_response, SearchDocsResponse):
            search_response = tool_response.rich_response

            # Create mapping from citation number to SearchDoc
            citation_to_doc: CitationMapping = {}
            for (
                citation_num,
                doc_id,
            ) in search_response.citation_mapping.items():
                # Find the SearchDoc with this doc_id
                matching_doc = next(
                    (
                        doc
                        for doc in search_response.search_docs
                        if doc.document_id == doc_id
                    ),
                    None,
                )
                if matching_doc:
                    citation_to_doc[citation_num] = matching_doc

            # Update the citation processor
            citation_processor.update_citation_mapping(citation_to_doc)


def collapse_citations(
    answer_text: str,
    existing_citation_mapping: CitationMapping,
    new_citation_mapping: CitationMapping,
) -> tuple[str, CitationMapping]:
    """Collapse the citations in the text to use the smallest possible numbers.

    This function takes citations in the text (like [25], [30], etc.) and replaces them
    with the smallest possible numbers. It starts numbering from the next available
    integer after the existing citation mapping. If a citation refers to a document
    that already exists in the existing citation mapping (matched by document_id),
    it uses the existing citation number instead of assigning a new one.

    Args:
        answer_text: The text containing citations to collapse (e.g., "See [25] and [30]")
        existing_citation_mapping: Citations already processed/displayed. These mappings
            are preserved unchanged in the output.
        new_citation_mapping: Citations from the current text that need to be collapsed.
            The keys are the citation numbers as they appear in answer_text.

    Returns:
        A tuple of (updated_text, combined_mapping) where:
        - updated_text: The text with citations replaced with collapsed numbers
        - combined_mapping: All values from existing_citation_mapping plus the new
          mappings with their (possibly renumbered) keys
    """
    # Build a reverse lookup: document_id -> existing citation number
    doc_id_to_existing_citation: dict[str, int] = {
        doc.document_id: citation_num
        for citation_num, doc in existing_citation_mapping.items()
    }

    # Determine the next available citation number
    if existing_citation_mapping:
        next_citation_num = max(existing_citation_mapping.keys()) + 1
    else:
        next_citation_num = 1

    # Build the mapping from old citation numbers (in new_citation_mapping) to new numbers
    old_to_new: dict[int, int] = {}
    additional_mappings: CitationMapping = {}

    for old_num, search_doc in new_citation_mapping.items():
        doc_id = search_doc.document_id

        # Check if this document already exists in existing citations
        if doc_id in doc_id_to_existing_citation:
            # Use the existing citation number
            old_to_new[old_num] = doc_id_to_existing_citation[doc_id]
        else:
            # Check if we've already assigned a new number to this document
            # (handles case where same doc appears with different old numbers)
            existing_new_num = None
            for mapped_old, mapped_new in old_to_new.items():
                if (
                    mapped_old in new_citation_mapping
                    and new_citation_mapping[mapped_old].document_id == doc_id
                ):
                    existing_new_num = mapped_new
                    break

            if existing_new_num is not None:
                old_to_new[old_num] = existing_new_num
            else:
                # Assign the next available number
                old_to_new[old_num] = next_citation_num
                additional_mappings[next_citation_num] = search_doc
                next_citation_num += 1

    # Pattern to match citations like [25], [1, 2, 3], [[25]], etc.
    # Also matches unicode bracket variants: 【】, []
    citation_pattern = re.compile(
        r"([\[【[]{2}\d+[\]】]]{2})|([\[【[]\d+(?:, ?\d+)*[\]】]])"
    )

    def replace_citation(match: re.Match) -> str:
        """Replace citation numbers in a match with their new collapsed values."""
        citation_str = match.group()

        # Determine bracket style
        if (
            citation_str.startswith("[[")
            or citation_str.startswith("【【")
            or citation_str.startswith("[[")
        ):
            open_bracket = citation_str[:2]
            close_bracket = citation_str[-2:]
            content = citation_str[2:-2]
        else:
            open_bracket = citation_str[0]
            close_bracket = citation_str[-1]
            content = citation_str[1:-1]

        # Parse and replace citation numbers
        new_nums = []
        for num_str in content.split(","):
            num_str = num_str.strip()
            if not num_str:
                continue
            try:
                num = int(num_str)
                # Only replace if we have a mapping for this number
                if num in old_to_new:
                    new_nums.append(str(old_to_new[num]))
                else:
                    # Keep original if not in our mapping
                    new_nums.append(num_str)
            except ValueError:
                new_nums.append(num_str)

        # Reconstruct the citation with original bracket style
        new_content = ", ".join(new_nums)
        return f"{open_bracket}{new_content}{close_bracket}"

    # Replace all citations in the text
    updated_text = citation_pattern.sub(replace_citation, answer_text)

    # Build the combined mapping
    combined_mapping: CitationMapping = dict(existing_citation_mapping)
    combined_mapping.update(additional_mappings)

    return updated_text, combined_mapping

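A hedged worked example of collapse_citations, using a minimal Doc stand-in that only carries the document_id attribute the function relies on (the real mappings hold SearchDoc objects):

from dataclasses import dataclass
from onyx.chat.citation_utils import collapse_citations

@dataclass
class Doc:
    document_id: str

existing = {1: Doc("doc-a"), 2: Doc("doc-b")}
new = {25: Doc("doc-b"), 30: Doc("doc-c")}

text, combined = collapse_citations("See [25] and [30].", existing, new)
print(text)                     # "See [2] and [3]." -- [25] reuses doc-b's existing number, [30] becomes 3
print(sorted(combined.keys()))  # [1, 2, 3]
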
@@ -1,15 +1,14 @@
import json
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import Session

from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import TOOL_CALL_MSG_ARGUMENTS
from onyx.chat.llm_step import TOOL_CALL_MSG_FUNC_NAME
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import LlmStepResult
@@ -30,18 +29,18 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.models import (
    FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
@@ -64,7 +63,7 @@ MAX_LLM_CYCLES = 6
def _build_project_file_citation_mapping(
    project_file_metadata: list[ProjectFileMetadata],
    starting_citation_num: int = 1,
) -> dict[int, SearchDoc]:
) -> CitationMapping:
    """Build citation mapping for project files.

    Converts project file metadata into SearchDoc objects that can be cited.
@@ -77,7 +76,7 @@ def _build_project_file_citation_mapping(
    Returns:
        Dictionary mapping citation numbers to SearchDoc objects
    """
    citation_mapping: dict[int, SearchDoc] = {}
    citation_mapping: CitationMapping = {}

    for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
        # Create a SearchDoc for each project file
@@ -293,8 +292,16 @@ def run_llm_loop(
    db_session: Session,
    forced_tool_id: int | None = None,
    user_identity: LLMUserIdentity | None = None,
    chat_session_id: str | None = None,
) -> None:
    with trace("run_llm_loop", metadata={"tenant_id": get_current_tenant_id()}):
    with trace(
        "run_llm_loop",
        group_id=chat_session_id,
        metadata={
            "tenant_id": get_current_tenant_id(),
            "chat_session_id": chat_session_id,
        },
    ):
        # Fix some LiteLLM issues,
        from onyx.llm.litellm_singleton.config import (
            initialize_litellm,
@@ -302,18 +309,11 @@ def run_llm_loop(

        initialize_litellm()

        stopping_tools_names: list[str] = [ImageGenerationTool.NAME]
        citeable_tools_names: list[str] = [
            SearchTool.NAME,
            WebSearchTool.NAME,
            OpenURLTool.NAME,
        ]

        # Initialize citation processor for handling citations dynamically
        citation_processor = DynamicCitationProcessor()

        # Add project file citation mappings if project files are present
        project_citation_mapping: dict[int, SearchDoc] = {}
        project_citation_mapping: CitationMapping = {}
        if project_files.project_file_metadata:
            project_citation_mapping = _build_project_file_citation_mapping(
                project_files.project_file_metadata
@@ -325,7 +325,6 @@ def run_llm_loop(
        # Pass the total budget to construct_message_history, which will handle token allocation
        available_tokens = llm.config.max_input_tokens
        tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO
        collected_tool_calls: list[ToolCallInfo] = []
        # Initialize gathered_documents with project files if present
        gathered_documents: list[SearchDoc] | None = (
            list(project_citation_mapping.values())
@@ -343,12 +342,8 @@ def run_llm_loop(
        has_called_search_tool: bool = False
        citation_mapping: dict[int, str] = {}  # Maps citation_num -> document_id/URL

        current_tool_call_index = (
            0  # TODO: just use the cycle count after parallel tool calls are supported
        )

        reasoning_cycles = 0
        for llm_cycle_count in range(MAX_LLM_CYCLES):

            if forced_tool_id:
                # Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
                final_tools = [tool for tool in tools if tool.id == forced_tool_id]
@@ -445,12 +440,13 @@ def run_llm_loop(

            # This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
            # It also pre-processes the tool calls in preparation for running them
            step_generator = run_llm_step(
            llm_step_result, has_reasoned = run_llm_step(
                emitter=emitter,
                history=truncated_message_history,
                tool_definitions=[tool.tool_definition() for tool in final_tools],
                tool_choice=tool_choice,
                llm=llm,
                turn_index=current_tool_call_index,
                placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
                citation_processor=citation_processor,
                state_container=state_container,
                # The rich docs representation is passed in so that when yielding the answer, it can also
@@ -459,18 +455,8 @@ def run_llm_loop(
                final_documents=gathered_documents,
                user_identity=user_identity,
            )

            # Consume the generator, emitting packets and capturing the final result
            while True:
                try:
                    packet = next(step_generator)
                    emitter.emit(packet)
                except StopIteration as e:
                    llm_step_result, current_tool_call_index = e.value
                    break

            # Type narrowing: generator always returns a result, so this can't be None
            llm_step_result = cast(LlmStepResult, llm_step_result)
            if has_reasoned:
                reasoning_cycles += 1

            # Save citation mapping after each LLM step for incremental state updates
            state_container.set_citation_mapping(citation_processor.citation_to_doc)
@@ -480,21 +466,50 @@ def run_llm_loop(
            tool_responses: list[ToolResponse] = []
            tool_calls = llm_step_result.tool_calls or []

            just_ran_web_search = False
            for tool_call in tool_calls:
                # TODO replace the [tool_call] with the list of tool calls once parallel tool calls are supported
                tool_responses, citation_mapping = run_tool_calls(
                    tool_calls=[tool_call],
                    tools=final_tools,
                    turn_index=current_tool_call_index,
                    message_history=truncated_message_history,
                    memories=memories,
                    user_info=None,  # TODO, this is part of memories right now, might want to separate it out
                    citation_mapping=citation_mapping,
                    citation_processor=citation_processor,
                    skip_search_query_expansion=has_called_search_tool,
            if len(tool_calls) > 1:
                emitter.emit(
                    Packet(
                        placement=Placement(
                            turn_index=tool_calls[0].placement.turn_index
                        ),
                        obj=TopLevelBranching(num_parallel_branches=len(tool_calls)),
                    )
                )

            # Quick note for why citation_mapping and citation_processors are both needed:
            # 1. Tools return lightweight string mappings, not SearchDoc objects
            # 2. The SearchDoc resolution is deliberately deferred to llm_loop.py
            # 3. The citation_processor operates on SearchDoc objects and can't provide a complete reverse URL lookup for
            #    in-flight citations
            # It can be cleaned up but not super trivial or worthwhile right now
            just_ran_web_search = False
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
memories=memories,
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
|
||||
# Failure case, give something reasonable to the LLM to try again
|
||||
if tool_calls and not tool_responses:
|
||||
failure_messages = create_tool_call_failure_messages(
|
||||
tool_calls[0], token_counter
|
||||
)
|
||||
simple_chat_history.extend(failure_messages)
|
||||
continue
|
||||
|
||||
for tool_response in tool_responses:
|
||||
# Extract tool_call from the response (set by run_tool_calls)
|
||||
if tool_response.tool_call is None:
|
||||
raise ValueError("Tool response missing tool_call reference")
|
||||
|
||||
tool_call = tool_response.tool_call
|
||||
tab_index = tool_call.placement.tab_index
|
||||
|
||||
# Track if search tool was called (for skipping query expansion on subsequent calls)
|
||||
if tool_call.tool_name == SearchTool.NAME:
|
||||
has_called_search_tool = True
|
||||
@@ -502,110 +517,81 @@ def run_llm_loop(
|
||||
# Build a mapping of tool names to tool objects for getting tool_id
|
||||
tools_by_name = {tool.name: tool for tool in final_tools}
|
||||
|
||||
# Add the results to the chat history, note that even if the tools were run in parallel, this isn't supported
|
||||
# as all the LLM APIs require linear history, so these will just be included sequentially
|
||||
for tool_call, tool_response in zip([tool_call], tool_responses):
|
||||
# Get the tool object to retrieve tool_id
|
||||
tool = tools_by_name.get(tool_call.tool_name)
|
||||
if not tool:
|
||||
raise ValueError(
|
||||
f"Tool '{tool_call.tool_name}' not found in tools list"
|
||||
)
|
||||
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_docs = tool_response.rich_response.search_docs
|
||||
if gathered_documents:
|
||||
gathered_documents.extend(search_docs)
|
||||
else:
|
||||
gathered_documents = search_docs
|
||||
|
||||
# This is used for the Open URL reminder in the next cycle
|
||||
# only do this if the web search tool yielded results
|
||||
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
|
||||
just_ran_web_search = True
|
||||
|
||||
# Extract generated_images if this is an image generation tool response
|
||||
generated_images = None
|
||||
if isinstance(
|
||||
tool_response.rich_response, FinalImageGenerationResponse
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=current_tool_call_index,
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
search_docs=search_docs,
|
||||
generated_images=generated_images,
|
||||
# Add the results to the chat history. Even though tools may run in parallel,
|
||||
# LLM APIs require linear history, so results are added sequentially.
|
||||
# Get the tool object to retrieve tool_id
|
||||
tool = tools_by_name.get(tool_call.tool_name)
|
||||
if not tool:
|
||||
raise ValueError(
|
||||
f"Tool '{tool_call.tool_name}' not found in tools list"
|
||||
)
|
||||
collected_tool_calls.append(tool_call_info)
|
||||
# Add to state container for partial save support
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
# Store tool call with function name and arguments in separate layers
|
||||
tool_call_data = {
|
||||
TOOL_CALL_MSG_FUNC_NAME: tool_call.tool_name,
|
||||
TOOL_CALL_MSG_ARGUMENTS: tool_call.tool_args,
|
||||
}
|
||||
tool_call_message = json.dumps(tool_call_data)
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_docs = tool_response.rich_response.search_docs
|
||||
if gathered_documents:
|
||||
gathered_documents.extend(search_docs)
|
||||
else:
|
||||
gathered_documents = search_docs
|
||||
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
# This is used for the Open URL reminder in the next cycle
|
||||
# only do this if the web search tool yielded results
|
||||
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
|
||||
just_ran_web_search = True
|
||||
|
||||
tool_response_message = tool_response.llm_facing_response
|
||||
tool_response_token_count = token_counter(tool_response_message)
|
||||
# Extract generated_images if this is an image generation tool response
|
||||
generated_images = None
|
||||
if isinstance(
|
||||
tool_response.rich_response, FinalImageGenerationResponse
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
tool_response_msg = ChatMessageSimple(
|
||||
message=tool_response_message,
|
||||
token_count=tool_response_token_count,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_response_msg)
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
search_docs=search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
# Update citation processor if this was a search tool
|
||||
if tool_call.tool_name in citeable_tools_names:
|
||||
# Check if the rich_response is a SearchDocsResponse
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_response = tool_response.rich_response
|
||||
# Store tool call with function name and arguments in separate layers
|
||||
tool_call_message = tool_call.to_msg_str()
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
|
||||
# Create mapping from citation number to SearchDoc
|
||||
citation_to_doc: dict[int, SearchDoc] = {}
|
||||
for (
|
||||
citation_num,
|
||||
doc_id,
|
||||
) in search_response.citation_mapping.items():
|
||||
# Find the SearchDoc with this doc_id
|
||||
matching_doc = next(
|
||||
(
|
||||
doc
|
||||
for doc in search_response.search_docs
|
||||
if doc.document_id == doc_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
citation_to_doc[citation_num] = matching_doc
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
|
||||
# Update the citation processor
|
||||
citation_processor.update_citation_mapping(citation_to_doc)
|
||||
tool_response_message = tool_response.llm_facing_response
|
||||
tool_response_token_count = token_counter(tool_response_message)
|
||||
|
||||
current_tool_call_index += 1
|
||||
tool_response_msg = ChatMessageSimple(
|
||||
message=tool_response_message,
|
||||
token_count=tool_response_token_count,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_response_msg)
|
||||
|
||||
# Update citation processor if this was a search tool
|
||||
update_citation_processor_from_tool_response(
|
||||
tool_response, citation_processor
|
||||
)
|
||||
|
||||
# If no tool calls, then it must have answered, wrap up
|
||||
if not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0:
|
||||
@@ -613,13 +599,13 @@ def run_llm_loop(
|
||||
|
||||
# Certain tools do not allow further actions, force the LLM wrap up on the next cycle
|
||||
if any(
|
||||
tool.tool_name in stopping_tools_names
|
||||
tool.tool_name in STOPPING_TOOLS_NAMES
|
||||
for tool in llm_step_result.tool_calls
|
||||
):
|
||||
ran_image_gen = True
|
||||
|
||||
if llm_step_result.tool_calls and any(
|
||||
tool.tool_name in citeable_tools_names
|
||||
tool.tool_name in CITEABLE_TOOLS_NAMES
|
||||
for tool in llm_step_result.tool_calls
|
||||
):
|
||||
# As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite
|
||||
@@ -629,5 +615,8 @@ def run_llm_loop(
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
|
||||
emitter.emit(
|
||||
Packet(turn_index=current_tool_call_index, obj=OverallStop(type="stop"))
|
||||
Packet(
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
@@ -7,6 +8,7 @@ from typing import cast
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
@@ -17,16 +19,19 @@ from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import Delta
|
||||
from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import FunctionCall
|
||||
from onyx.llm.models import ImageContentPart
|
||||
from onyx.llm.models import ImageUrlDetail
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import TextContentPart
|
||||
from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -34,6 +39,8 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDone
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import TOOL_CALL_MSG_ARGUMENTS
|
||||
from onyx.tools.models import TOOL_CALL_MSG_FUNC_NAME
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
@@ -43,8 +50,77 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()


TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"
def _try_parse_json_string(value: Any) -> Any:
    """Attempt to parse a JSON string value into its Python equivalent.

    If value is a string that looks like a JSON array or object, parse it.
    Otherwise return the value unchanged.

    This handles the case where the LLM returns arguments like:
    - queries: '["query1", "query2"]' instead of ["query1", "query2"]
    """
    if not isinstance(value, str):
        return value

    stripped = value.strip()
    # Only attempt to parse if it looks like a JSON array or object
    if not (
        (stripped.startswith("[") and stripped.endswith("]"))
        or (stripped.startswith("{") and stripped.endswith("}"))
    ):
        return value

    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        return value


def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
    """Parse tool arguments into a dict.

    Normal case:
    - raw_args == '{"queries":[...]}' -> dict via json.loads

    Defensive case (JSON string literal of an object):
    - raw_args == '"{\\"queries\\":[...]}"' -> json.loads -> str -> json.loads -> dict

    Also handles the case where argument values are JSON strings that need parsing:
    - {"queries": '["q1", "q2"]'} -> {"queries": ["q1", "q2"]}

    Anything else returns {}.
    """

    if raw_args is None:
        return {}

    if isinstance(raw_args, dict):
        # Parse any string values that look like JSON arrays/objects
        return {k: _try_parse_json_string(v) for k, v in raw_args.items()}

    if not isinstance(raw_args, str):
        return {}

    try:
        parsed1: Any = json.loads(raw_args)
    except json.JSONDecodeError:
        return {}

    if isinstance(parsed1, dict):
        # Parse any string values that look like JSON arrays/objects
        return {k: _try_parse_json_string(v) for k, v in parsed1.items()}

    if isinstance(parsed1, str):
        try:
            parsed2: Any = json.loads(parsed1)
        except json.JSONDecodeError:
            return {}
        if isinstance(parsed2, dict):
            # Parse any string values that look like JSON arrays/objects
            return {k: _try_parse_json_string(v) for k, v in parsed2.items()}
        return {}

    return {}
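For clarity, here is how the helper above handles the cases its docstring describes (illustrative calls, not part of the diff; outputs assume the code behaves as written):

# Illustrative only -- exercising _parse_tool_args_to_dict with the docstring's cases.
_parse_tool_args_to_dict('{"queries": ["q1", "q2"]}')      # -> {"queries": ["q1", "q2"]}
_parse_tool_args_to_dict('"{\\"queries\\": [\\"q1\\"]}"')  # double-encoded object -> {"queries": ["q1"]}
_parse_tool_args_to_dict({"queries": '["q1", "q2"]'})      # stringified value -> {"queries": ["q1", "q2"]}
_parse_tool_args_to_dict(None)                              # -> {}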
def _format_message_history_for_logging(
|
||||
@@ -153,21 +229,27 @@ def _update_tool_call_with_delta(
|
||||
|
||||
def _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]],
|
||||
turn_index: int,
|
||||
tab_index: int | None = None,
|
||||
sub_turn_index: int | None = None,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract ToolCallKickoff objects from the tool call map.
|
||||
|
||||
Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
|
||||
Each tool call is assigned the given turn_index and a tab_index based on its order.
|
||||
|
||||
Args:
|
||||
id_to_tool_call_map: Map of tool call index to tool call data
|
||||
turn_index: The turn index for this set of tool calls
|
||||
tab_index: If provided, use this tab_index for all tool calls (otherwise auto-increment)
|
||||
sub_turn_index: The sub-turn index for nested tool calls
|
||||
"""
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index_calculated = 0
|
||||
for tool_call_data in id_to_tool_call_map.values():
|
||||
if tool_call_data.get("id") and tool_call_data.get("name"):
|
||||
try:
|
||||
# Parse arguments JSON string to dict
|
||||
tool_args = (
|
||||
json.loads(tool_call_data["arguments"])
|
||||
if tool_call_data["arguments"]
|
||||
else {}
|
||||
)
|
||||
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, try empty dict, most tools would fail though
|
||||
logger.error(
|
||||
@@ -180,8 +262,16 @@ def _extract_tool_call_kickoffs(
|
||||
tool_call_id=tool_call_data["id"],
|
||||
tool_name=tool_call_data["name"],
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=(
|
||||
tab_index_calculated if tab_index is None else tab_index
|
||||
),
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index_calculated += 1
|
||||
return tool_calls
|
||||
|
||||
|
||||
@@ -272,13 +362,19 @@ def translate_history_to_llm_format(
|
||||
function_name = tool_call_data.get(
|
||||
TOOL_CALL_MSG_FUNC_NAME, "unknown"
|
||||
)
|
||||
tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
|
||||
raw_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
|
||||
else:
|
||||
function_name = "unknown"
|
||||
tool_args = (
|
||||
raw_args = (
|
||||
tool_call_data if isinstance(tool_call_data, dict) else {}
|
||||
)
|
||||
|
||||
# IMPORTANT: `FunctionCall.arguments` must be a JSON object string.
|
||||
# If `raw_args` is accidentally a JSON string literal of an object
|
||||
# (e.g. '"{\\"queries\\":[...]}"'), calling `json.dumps(raw_args)`
|
||||
# would produce a quoted JSON literal and break Anthropic tool parsing.
|
||||
tool_args = _parse_tool_args_to_dict(raw_args)
|
||||
|
||||
# NOTE: if the model is trained on a different tool call format, this may slightly interfere
|
||||
# with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
|
||||
tool_call = ToolCall(
|
||||
@@ -324,20 +420,87 @@ def translate_history_to_llm_format(
    return messages


def run_llm_step(
def _increment_turns(
    turn_index: int, sub_turn_index: int | None
) -> tuple[int, int | None]:
    if sub_turn_index is None:
        return turn_index + 1, None
    else:
        return turn_index, sub_turn_index + 1


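In short, reasoning advances the top-level turn index unless the step runs inside a sub-turn (nested agent/tool call), in which case only the sub-turn index advances. Illustrative calls (not part of the diff):

# Illustrative only.
_increment_turns(3, None)  # -> (4, None): top-level step, advance the turn
_increment_turns(3, 0)     # -> (3, 1): nested step, advance only the sub-turn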
def run_llm_step_pkt_generator(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
llm: LLM,
|
||||
turn_index: int,
|
||||
citation_processor: DynamicCitationProcessor,
|
||||
state_container: ChatStateContainer,
|
||||
placement: Placement,
|
||||
state_container: ChatStateContainer | None,
|
||||
citation_processor: DynamicCitationProcessor | None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
|
||||
# The second return value is for the turn index because reasoning counts on the frontend as a turn
|
||||
# TODO this is maybe ok but does not align well with the backend logic too well
|
||||
custom_token_processor: (
|
||||
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
|
||||
) = None,
|
||||
max_tokens: int | None = None,
|
||||
# TODO: Temporary handling of nested tool calls with agents, figure out a better way to handle this
|
||||
use_existing_tab_index: bool = False,
|
||||
is_deep_research: bool = False,
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, bool]]:
|
||||
"""Run an LLM step and stream the response as packets.
|
||||
NOTE: DO NOT TOUCH THIS FUNCTION BEFORE ASKING YUHONG, this is very finicky and
|
||||
delicate logic that is core to the app's main functionality.
|
||||
|
||||
This generator function streams LLM responses, processing reasoning content,
|
||||
answer content, tool calls, and citations. It yields Packet objects for
|
||||
real-time streaming to clients and accumulates the final result.
|
||||
|
||||
Args:
|
||||
history: List of chat messages in the conversation history.
|
||||
tool_definitions: List of tool definitions available to the LLM.
|
||||
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
|
||||
llm: Language model interface to use for generation.
|
||||
turn_index: Current turn index in the conversation.
|
||||
state_container: Container for storing chat state (reasoning, answers).
|
||||
citation_processor: Optional processor for extracting and formatting citations
|
||||
from the response. If provided, processes tokens to identify citations.
|
||||
reasoning_effort: Optional reasoning effort configuration for models that
|
||||
support reasoning (e.g., o1 models).
|
||||
final_documents: Optional list of search documents to include in the response
|
||||
start packet.
|
||||
user_identity: Optional user identity information for the LLM.
|
||||
custom_token_processor: Optional callable that processes each token delta
|
||||
before yielding. Receives (delta, processor_state) and returns
|
||||
(modified_delta, new_processor_state). Can return None for delta to skip.
|
||||
sub_turn_index: Optional sub-turn index for nested tool/agent calls.
|
||||
|
||||
Yields:
|
||||
Packet: Streaming packets containing:
|
||||
- ReasoningStart/ReasoningDelta/ReasoningDone for reasoning content
|
||||
- AgentResponseStart/AgentResponseDelta for answer content
|
||||
- CitationInfo for extracted citations
|
||||
- ToolCallKickoff for tool calls (extracted at the end)
|
||||
|
||||
Returns:
|
||||
tuple[LlmStepResult, bool]: A tuple containing:
|
||||
- LlmStepResult: The final result with accumulated reasoning, answer,
|
||||
and tool calls (if any).
|
||||
- bool: Whether reasoning occurred during this step. This should be used to
|
||||
increment the turn index or sub_turn index for the rest of the LLM loop.
|
||||
|
||||
Note:
|
||||
The function handles incremental state updates, saving reasoning and answer
|
||||
tokens to the state container as they are generated. Tool calls are extracted
|
||||
and yielded only after the stream completes.
|
||||
"""
|
||||
|
||||
turn_index = placement.turn_index
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
has_reasoned = 0
|
||||
|
||||
# Uncomment the line below to log the entire message history to the console
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
@@ -351,6 +514,8 @@ def run_llm_step(
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
|
||||
processor_state: Any = None
|
||||
|
||||
with generation_span(
|
||||
model=llm.config.model_name,
|
||||
model_config={
|
||||
@@ -366,7 +531,8 @@ def run_llm_step(
|
||||
tools=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=None, # TODO
|
||||
# reasoning_effort=ReasoningEffort.OFF, # Can set this for dev/testing.
|
||||
max_tokens=max_tokens,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
):
|
||||
if packet.usage:
|
||||
@@ -379,69 +545,173 @@ def run_llm_step(
|
||||
}
|
||||
delta = packet.choice.delta
|
||||
|
||||
if custom_token_processor:
|
||||
# The custom token processor can modify the deltas for specific custom logic
|
||||
# It can also return a state so that it can handle aggregated delta logic etc.
|
||||
# Loosely typed so the function can be flexible
|
||||
modified_delta, processor_state = custom_token_processor(
|
||||
delta, processor_state
|
||||
)
|
||||
if modified_delta is None:
|
||||
continue
|
||||
delta = modified_delta
|
||||
|
||||
# Should only happen once, frontend does not expect multiple
|
||||
# ReasoningStart or ReasoningDone packets.
|
||||
if delta.reasoning_content:
|
||||
accumulated_reasoning += delta.reasoning_content
|
||||
# Save reasoning incrementally to state container
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
# Treat content as reasoning when we know tool calls are coming
|
||||
accumulated_reasoning += delta.content
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
obj=result,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.content),
|
||||
)
|
||||
reasoning_start = True
|
||||
else:
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(
|
||||
accumulated_answer
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
if flush_delta and flush_delta.tool_calls:
|
||||
for tool_call_delta in flush_delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
tool_calls = _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map=id_to_tool_call_map,
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index if use_existing_tab_index else None,
|
||||
sub_turn_index=sub_turn_index,
|
||||
)
|
||||
if tool_calls:
|
||||
tool_calls_list: list[ToolCall] = [
|
||||
ToolCall(
|
||||
@@ -468,28 +738,48 @@ def run_llm_step(
|
||||
tool_calls=None,
|
||||
)
|
||||
span_generation.span_data.output = [assistant_msg_no_tools.model_dump()]
|
||||
# Close reasoning block if still open (stream ended with reasoning content)
|
||||
|
||||
# This may happen if the custom token processor is used to modify other packets into reasoning
|
||||
# Then there won't necessarily be anything else to come after the reasoning tokens
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
|
||||
reasoning_start = False
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
# Reasoning is always first so this should use the post-incremented value of turn_index
|
||||
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
|
||||
# as clickable items and will be stripped out instead.
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
|
||||
@@ -514,5 +804,55 @@ def run_llm_step(
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
turn_index,
|
||||
bool(has_reasoned),
|
||||
)
|
||||
|
||||
|
||||
def run_llm_step(
|
||||
emitter: Emitter,
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
llm: LLM,
|
||||
placement: Placement,
|
||||
state_container: ChatStateContainer | None,
|
||||
citation_processor: DynamicCitationProcessor | None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
custom_token_processor: (
|
||||
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
|
||||
) = None,
|
||||
max_tokens: int | None = None,
|
||||
use_existing_tab_index: bool = False,
|
||||
is_deep_research: bool = False,
|
||||
) -> tuple[LlmStepResult, bool]:
|
||||
"""Wrapper around run_llm_step_pkt_generator that consumes packets and emits them.
|
||||
|
||||
Returns:
|
||||
tuple[LlmStepResult, bool]: The LLM step result and whether reasoning occurred.
|
||||
"""
|
||||
step_generator = run_llm_step_pkt_generator(
|
||||
history=history,
|
||||
tool_definitions=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=placement,
|
||||
state_container=state_container,
|
||||
citation_processor=citation_processor,
|
||||
reasoning_effort=reasoning_effort,
|
||||
final_documents=final_documents,
|
||||
user_identity=user_identity,
|
||||
custom_token_processor=custom_token_processor,
|
||||
max_tokens=max_tokens,
|
||||
use_existing_tab_index=use_existing_tab_index,
|
||||
is_deep_research=is_deep_research,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(step_generator)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, has_reasoned = e.value
|
||||
return llm_step_result, bool(has_reasoned)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -7,9 +6,8 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_milestones import process_multi_assistant_milestone
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_llm_with_state_containers
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
@@ -32,6 +30,7 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
@@ -51,8 +50,8 @@ from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
@@ -65,13 +64,14 @@ from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolUsage
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -367,11 +367,10 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Milestone tracking, most devs using the API don't need to understand this
|
||||
process_multi_assistant_milestone(
|
||||
user=user,
|
||||
assistant_id=persona.id,
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
@@ -379,7 +378,7 @@ def stream_chat_message_objects(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
@@ -475,7 +474,6 @@ def stream_chat_message_objects(
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=user_selected_filters,
|
||||
project_id=(
|
||||
@@ -546,7 +544,7 @@ def stream_chat_message_objects(
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
# Note: DB session is not thread safe but nothing else uses it and the
|
||||
# reference is passed directly so it's ok.
|
||||
if os.environ.get("ENABLE_DEEP_RESEARCH_LOOP"): # Dev only feature flag for now
|
||||
if new_msg_req.deep_research:
|
||||
if chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
|
||||
@@ -554,7 +552,7 @@ def stream_chat_message_objects(
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
@@ -567,9 +565,10 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session_id),
|
||||
)
|
||||
else:
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
@@ -589,6 +588,7 @@ def stream_chat_message_objects(
|
||||
else None
|
||||
),
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session_id),
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
@@ -156,7 +156,7 @@ def build_system_prompt(
        system_prompt += company_context
    if memories:
        system_prompt += "\n".join(
            memory.strip() for memory in memories if memory.strip()
            "- " + memory.strip() for memory in memories if memory.strip()
        )

    # Append citation guidance after company context if placeholder was not present
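The change prefixes each memory with "- " so memories render as a bulleted list in the system prompt. Roughly (illustrative, with made-up example memory strings):

# Illustrative only.
memories = ["Prefers metric units ", "", "Works in the Berlin office"]
"\n".join("- " + m.strip() for m in memories if m.strip())
# -> "- Prefers metric units\n- Works in the Berlin office"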
@@ -102,6 +102,7 @@ def _create_and_link_tool_calls(
|
||||
if tool_call_info.generated_images
|
||||
else None
|
||||
),
|
||||
tab_index=tool_call_info.tab_index,
|
||||
add_only=True,
|
||||
)
|
||||
|
||||
@@ -219,8 +220,8 @@ def save_chat_turn(
|
||||
search_doc_key_to_id[search_doc_key] = db_search_doc.id
|
||||
search_doc_ids_for_tool.append(db_search_doc.id)
|
||||
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = (
|
||||
search_doc_ids_for_tool
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
|
||||
set(search_doc_ids_for_tool)
|
||||
)
|
||||
|
||||
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
|
||||
|
||||
@@ -541,6 +541,11 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# Default size threshold for Drupal Wiki attachments (10MB)
|
||||
DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# Default size threshold for SharePoint files (20MB)
|
||||
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
@@ -583,6 +588,16 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
|
||||
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
|
||||
|
||||
# Slack federated search thread context settings
|
||||
# Batch size for fetching thread context (controls concurrent API calls per batch)
|
||||
SLACK_THREAD_CONTEXT_BATCH_SIZE = int(
|
||||
os.environ.get("SLACK_THREAD_CONTEXT_BATCH_SIZE", "5")
|
||||
)
|
||||
# Maximum messages to fetch thread context for (top N by relevance get full context)
|
||||
MAX_SLACK_THREAD_CONTEXT_MESSAGES = int(
|
||||
os.environ.get("MAX_SLACK_THREAD_CONTEXT_MESSAGES", "5")
|
||||
)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -698,6 +713,15 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
# The intent was to have this be configurable per query, but I don't think any
|
||||
# codepath was actually configuring this, so for the migrated Vespa interface
|
||||
# we'll just use the default value, but also have it be configurable by env var.
|
||||
RECENCY_BIAS_MULTIPLIER = float(os.environ.get("RECENCY_BIAS_MULTIPLIER") or 1.0)
|
||||
|
||||
# Should match the rerank-count value set in
|
||||
# backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja.
|
||||
RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
|
||||
|
||||
|
||||
#####
|
||||
# Tool Configs
|
||||
|
||||
@@ -209,6 +209,7 @@ class DocumentSource(str, Enum):
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
HIGHSPOT = "highspot"
|
||||
DRUPAL_WIKI = "drupal_wiki"
|
||||
|
||||
IMAP = "imap"
|
||||
BITBUCKET = "bitbucket"
|
||||
@@ -332,7 +333,6 @@ class FileType(str, Enum):
|
||||
class MilestoneRecordType(str, Enum):
|
||||
TENANT_CREATED = "tenant_created"
|
||||
USER_SIGNED_UP = "user_signed_up"
|
||||
MULTIPLE_USERS = "multiple_users"
|
||||
VISITED_ADMIN_PAGE = "visited_admin_page"
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
@@ -564,7 +564,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
@@ -629,6 +629,7 @@ project management, and collaboration tools into a single, customizable platform
|
||||
DocumentSource.EGNYTE: "egnyte - files",
|
||||
DocumentSource.AIRTABLE: "airtable - database",
|
||||
DocumentSource.HIGHSPOT: "highspot - CRM data",
|
||||
DocumentSource.DRUPAL_WIKI: "drupal wiki - knowledge base content (pages, spaces, attachments)",
|
||||
DocumentSource.IMAP: "imap - email data",
|
||||
DocumentSource.TESTRAIL: "testrail - test case management tool for QA processes",
|
||||
}
|
||||
|
||||
@@ -64,12 +64,12 @@ _BASE_EMBEDDING_MODELS = [
|
||||
_BaseEmbeddingModel(
|
||||
name="google/gemini-embedding-001",
|
||||
dim=3072,
|
||||
index_name="danswer_chunk_google_gemini_embedding_001",
|
||||
index_name="danswer_chunk_gemini_embedding_001",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/text-embedding-005",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_text_embedding_005",
|
||||
index_name="danswer_chunk_text_embedding_005",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="voyage/voyage-large-2-instruct",
|
||||
|
||||
@@ -51,10 +51,9 @@ CROSS_ENCODER_RANGE_MIN = 0
|
||||
# Generative AI Model Configs
|
||||
#####
|
||||
|
||||
# NOTE: the 3 below should only be used for dev.
|
||||
# NOTE: the 2 below should only be used for dev.
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
|
||||
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
|
||||
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
|
||||
@@ -38,7 +38,7 @@ class AsanaAPI:
|
||||
def __init__(
|
||||
self, api_token: str, workspace_gid: str, team_gid: str | None
|
||||
) -> None:
|
||||
self._user = None # type: ignore
|
||||
self._user = None
|
||||
self.workspace_gid = workspace_gid
|
||||
self.team_gid = team_gid
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from typing import Any
|
||||
from typing import Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from botocore.credentials import RefreshableCredentials
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from botocore.exceptions import PartialCredentialsError
|
||||
from botocore.session import get_session
|
||||
from mypy_boto3_s3 import S3Client # type: ignore
|
||||
from mypy_boto3_s3 import S3Client
|
||||
|
||||
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -40,8 +40,7 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -410,7 +409,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
continue
|
||||
|
||||
# Handle image files
|
||||
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
|
||||
if file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
|
||||
if not self._allow_images:
|
||||
logger.debug(
|
||||
f"Skipping image file: {key} (image processing not enabled)"
|
||||
|
||||
@@ -84,6 +84,12 @@ ONE_DAY = ONE_HOUR * 24
MAX_CACHED_IDS = 100


def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str:
    if allow_missing and "id" not in page:
        return "unknown"
    return str(page["id"])


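The new _get_page_id helper centralizes page-id extraction so the call sites changed below can tolerate pages missing an "id" field. Illustrative behavior (not part of the diff):

# Illustrative only.
_get_page_id({"id": 12345})                            # -> "12345"
_get_page_id({"title": "Orphan"}, allow_missing=True)  # -> "unknown"
_get_page_id({"title": "Orphan"})                      # raises KeyError, same as the old page["id"] access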
class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
next_page_url: str | None
|
||||
@@ -299,7 +305,7 @@ class ConfluenceConnector(
|
||||
page_id = page_url = ""
|
||||
try:
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_id = _get_page_id(page)
|
||||
page_title = page["title"]
|
||||
logger.info(f"Converting page {page_title} to document")
|
||||
page_url = build_confluence_document_id(
|
||||
@@ -382,7 +388,9 @@ class ConfluenceConnector(
|
||||
this function. The returned documents/connectorfailures are for non-inline attachments
|
||||
and those at the end of the page.
|
||||
"""
|
||||
attachment_query = self._construct_attachment_query(page["id"], start, end)
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
attachment_failures: list[ConnectorFailure] = []
|
||||
attachment_docs: list[Document] = []
|
||||
page_url = ""
|
||||
@@ -430,7 +438,7 @@ class ConfluenceConnector(
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
page_id=_get_page_id(page),
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
@@ -515,14 +523,21 @@ class ConfluenceConnector(
|
||||
except HTTPError as e:
|
||||
# If we get a 403 after all retries, the user likely doesn't have permission
|
||||
# to access attachments on this page. Log and skip rather than failing the whole job.
|
||||
if e.response and e.response.status_code == 403:
|
||||
page_title = page.get("title", "unknown")
|
||||
page_id = page.get("id", "unknown")
|
||||
logger.warning(
|
||||
f"Permission denied (403) when fetching attachments for page '{page_title}' "
|
||||
page_id = _get_page_id(page, allow_missing=True)
|
||||
page_title = page.get("title", "unknown")
|
||||
if e.response and e.response.status_code in [401, 403]:
|
||||
failure_message_prefix = (
|
||||
"Invalid credentials (401)"
|
||||
if e.response.status_code == 401
|
||||
else "Permission denied (403)"
|
||||
)
|
||||
failure_message = (
|
||||
f"{failure_message_prefix} when fetching attachments for page '{page_title}' "
|
||||
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
|
||||
"Skipping attachments for this page."
|
||||
)
|
||||
logger.warning(failure_message)
|
||||
|
||||
# Build the page URL for the failure record
|
||||
try:
|
||||
page_url = build_confluence_document_id(
|
||||
@@ -537,7 +552,7 @@ class ConfluenceConnector(
|
||||
document_id=page_id,
|
||||
document_link=page_url,
|
||||
),
|
||||
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
|
||||
failure_message=failure_message,
|
||||
exception=e,
|
||||
)
|
||||
]
|
||||
@@ -708,7 +723,7 @@ class ConfluenceConnector(
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
page_id = page["id"]
|
||||
page_id = _get_page_id(page)
|
||||
page_restrictions = page.get("restrictions") or {}
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
@@ -728,7 +743,7 @@ class ConfluenceConnector(
|
||||
)
|
||||
|
||||
# Query attachments for each page
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -24,9 +24,9 @@ from onyx.configs.app_configs import (
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.file_types import OnyxMimeTypes
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -56,15 +56,13 @@ def validate_attachment_filetype(
|
||||
"""
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
if media_type.startswith("image/"):
|
||||
return is_valid_image_type(media_type)
|
||||
return media_type in OnyxMimeTypes.IMAGE_MIME_TYPES
|
||||
|
||||
# For non-image files, check if we support the extension
|
||||
title = attachment.get("title", "")
|
||||
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
|
||||
extension = get_file_ext(title)
|
||||
|
||||
return is_accepted_file_ext(
|
||||
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
)
|
||||
return extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS
|
||||
|
||||
|
||||
class AttachmentProcessingResult(BaseModel):
|
||||
|
||||
@@ -71,6 +71,13 @@ def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
raise ValueError(f"Unable to parse datetime string: {datetime_str}")
|
||||
|
||||
|
||||
# TODO: use this function in other connectors
|
||||
def datetime_from_utc_timestamp(timestamp: int) -> datetime:
|
||||
"""Convert a Unix timestamp to a datetime object in UTC"""
|
||||
|
||||
return datetime.fromtimestamp(timestamp, tz=timezone.utc)
|
||||
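Reviewer note: a quick sanity check of the new helper, assuming the import path that the Drupal Wiki connector below uses; this is an illustrative sketch, not part of the diff.

from datetime import datetime, timezone

from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    datetime_from_utc_timestamp,
)

# Epoch second 0 is 1970-01-01T00:00:00 UTC; the returned datetime is timezone-aware.
assert datetime_from_utc_timestamp(0) == datetime(1970, 1, 1, tzinfo=timezone.utc)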
|
||||
|
||||
def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
if info.first_name and info.last_name:
|
||||
return f"{info.first_name} {info.middle_initial} {info.last_name}"
|
||||
|
||||
@@ -2,11 +2,11 @@ from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox # type: ignore
|
||||
from dropbox.exceptions import ApiError # type:ignore
|
||||
from dropbox.exceptions import AuthError # type:ignore
|
||||
from dropbox.files import FileMetadata # type:ignore
|
||||
from dropbox.files import FolderMetadata # type:ignore
|
||||
from dropbox import Dropbox # type: ignore[import-untyped]
|
||||
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
|
||||
from dropbox.exceptions import AuthError
|
||||
from dropbox.files import FileMetadata # type: ignore[import-untyped]
|
||||
from dropbox.files import FolderMetadata
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
0
backend/onyx/connectors/drupal_wiki/__init__.py
Normal file
907
backend/onyx/connectors/drupal_wiki/connector.py
Normal file
@@ -0,0 +1,907 @@
|
||||
import mimetypes
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
datetime_from_utc_timestamp,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
|
||||
from onyx.connectors.drupal_wiki.models import DrupalWikiSpaceResponse
|
||||
from onyx.connectors.drupal_wiki.utils import build_drupal_wiki_document_id
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
MAX_API_PAGE_SIZE = 2000 # max allowed by API
|
||||
DRUPAL_WIKI_SPACE_KEY = "space"
|
||||
|
||||
|
||||
rate_limited_get = retry_builder()(
|
||||
rate_limit_builder(max_calls=10, period=1)(rl_requests.get)
|
||||
)
|
||||
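Reviewer note: the composed wrapper above is equivalent to stacking the two builders as decorators. A sketch of that form, using the same assumed helpers; because the retry layer wraps the rate limiter, every retry attempt also counts against the 10-calls-per-second budget.

# Illustrative only: same behavior as the expression above.
@retry_builder()
@rate_limit_builder(max_calls=10, period=1)
def _rate_limited_get(url: str, **kwargs: Any) -> requests.Response:
    return rl_requests.get(url, **kwargs)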
|
||||
|
||||
class DrupalWikiConnector(
|
||||
CheckpointedConnector[DrupalWikiCheckpoint],
|
||||
SlimConnector,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
spaces: list[str] | None = None,
|
||||
pages: list[str] | None = None,
|
||||
include_all_spaces: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
drupal_wiki_scope: str | None = None,
|
||||
include_attachments: bool = False,
|
||||
allow_images: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Drupal Wiki connector.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the Drupal Wiki instance (e.g., https://help.drupal-wiki.com)
|
||||
spaces: List of space IDs to index. If None and include_all_spaces is False, no spaces will be indexed.
|
||||
pages: List of page IDs to index. If provided, only these specific pages will be indexed.
|
||||
include_all_spaces: If True, all spaces will be indexed regardless of the spaces parameter.
|
||||
batch_size: Number of documents to process in a batch.
|
||||
continue_on_failure: If True, continue indexing even if some documents fail.
|
||||
drupal_wiki_scope: The selected tab value from the frontend. If "all_spaces", all spaces will be indexed.
|
||||
include_attachments: If True, enable processing of page attachments including images and documents.
|
||||
allow_images: If True, enable processing of image attachments.
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.spaces = spaces or []
|
||||
self.pages = pages or []
|
||||
|
||||
# Determine whether to include all spaces based on the selected tab
|
||||
# If drupal_wiki_scope is "all_spaces", we should index all spaces
|
||||
# If it's "specific_spaces", we should only index the specified spaces
|
||||
# If it's None, we use the include_all_spaces parameter
|
||||
|
||||
if drupal_wiki_scope is not None:
|
||||
logger.debug(f"drupal_wiki_scope is set to {drupal_wiki_scope}")
|
||||
|
||||
self.include_all_spaces = drupal_wiki_scope == "all_spaces"
|
||||
# If scope is specific_spaces, include_all_spaces correctly defaults to False
|
||||
else:
|
||||
logger.debug(
|
||||
f"drupal_wiki_scope is not set, using include_all_spaces={include_all_spaces}"
|
||||
)
|
||||
self.include_all_spaces = include_all_spaces
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
|
||||
# Attachment processing configuration
|
||||
self.include_attachments = include_attachments
|
||||
self.allow_images = allow_images
|
||||
|
||||
self.headers: dict[str, str] = {"Accept": "application/json"}
|
||||
self._api_token: str | None = None # set by load_credentials
|
||||
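Reviewer note: a small sketch of how the scope flag resolves, using placeholder URLs and space IDs (illustrative only).

c1 = DrupalWikiConnector("https://wiki.example.com", drupal_wiki_scope="all_spaces")
c2 = DrupalWikiConnector("https://wiki.example.com", spaces=["12"], drupal_wiki_scope="specific_spaces")
c3 = DrupalWikiConnector("https://wiki.example.com", include_all_spaces=True)  # no scope -> constructor arg wins
assert c1.include_all_spaces and not c2.include_all_spaces and c3.include_all_spaces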
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
logger.info(f"Setting allow_images to {value}.")
|
||||
self.allow_images = value
|
||||
|
||||
def _get_page_attachments(self, page_id: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all attachments for a specific page.
|
||||
|
||||
Args:
|
||||
page_id: ID of the page.
|
||||
|
||||
Returns:
|
||||
List of attachment dictionaries.
|
||||
"""
|
||||
url = f"{self.base_url}/api/rest/scope/api/attachment"
|
||||
params = {"pageId": str(page_id)}
|
||||
logger.debug(f"Fetching attachments for page {page_id} from {url}")
|
||||
|
||||
try:
|
||||
response = rate_limited_get(url, headers=self.headers, params=params)
|
||||
response.raise_for_status()
|
||||
attachments = response.json()
|
||||
logger.info(f"Found {len(attachments)} attachments for page {page_id}")
|
||||
return attachments
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch attachments for page {page_id}: {e}")
|
||||
return []
|
||||
|
||||
def _download_attachment(self, attachment_id: int) -> bytes:
|
||||
"""
|
||||
Download attachment content.
|
||||
|
||||
Args:
|
||||
attachment_id: ID of the attachment to download.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the attachment.
|
||||
"""
|
||||
url = f"{self.base_url}/api/rest/scope/api/attachment/{attachment_id}/download"
|
||||
logger.info(f"Downloading attachment {attachment_id} from {url}")
|
||||
|
||||
# Use headers without Accept for binary downloads
|
||||
download_headers = {"Authorization": f"Bearer {self._api_token}"}
|
||||
|
||||
response = rate_limited_get(url, headers=download_headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.content
|
||||
|
||||
def _validate_attachment_filetype(self, attachment: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Validate if the attachment file type is supported.
|
||||
|
||||
Args:
|
||||
attachment: Attachment dictionary from Drupal Wiki API.
|
||||
|
||||
Returns:
|
||||
True if the file type is supported, False otherwise.
|
||||
"""
|
||||
file_name = attachment.get("fileName", "")
|
||||
if not file_name:
|
||||
return False
|
||||
|
||||
# Get file extension
|
||||
file_extension = get_file_ext(file_name)
|
||||
|
||||
if file_extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
return True
|
||||
|
||||
logger.warning(f"Unsupported file type: {file_extension} for {file_name}")
|
||||
return False
|
||||
|
||||
def _get_media_type_from_filename(self, filename: str) -> str:
|
||||
"""
|
||||
Get media type from filename using the standard mimetypes library.
|
||||
|
||||
Args:
|
||||
filename: The filename.
|
||||
|
||||
Returns:
|
||||
Media type string.
|
||||
"""
|
||||
mime_type, _encoding = mimetypes.guess_type(filename)
|
||||
return mime_type or "application/octet-stream"
|
||||
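Reviewer note: for reference, the standard-library behavior this helper leans on (filenames are illustrative).

import mimetypes

# Known extensions resolve to a concrete type; unknown ones fall back to a generic binary type.
assert mimetypes.guess_type("report.pdf")[0] == "application/pdf"
assert (mimetypes.guess_type("data.unknownext")[0] or "application/octet-stream") == "application/octet-stream"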
|
||||
def _process_attachment(
|
||||
self,
|
||||
attachment: dict[str, Any],
|
||||
page_id: int,
|
||||
download_url: str,
|
||||
) -> tuple[list[TextSection | ImageSection], str | None]:
|
||||
"""
|
||||
Process a single attachment and return generated sections.
|
||||
|
||||
Args:
|
||||
attachment: Attachment dictionary from Drupal Wiki API.
|
||||
page_id: ID of the parent page.
|
||||
download_url: Direct download URL for the attachment.
|
||||
|
||||
Returns:
|
||||
Tuple of (sections, error_message). If error_message is not None, the
|
||||
sections list should be treated as invalid.
|
||||
"""
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
try:
|
||||
if not self._validate_attachment_filetype(attachment):
|
||||
return (
|
||||
[],
|
||||
f"Unsupported file type: {attachment.get('fileName', 'unknown')}",
|
||||
)
|
||||
|
||||
attachment_id = attachment["id"]
|
||||
file_name = attachment.get("fileName", f"attachment_{attachment_id}")
|
||||
file_size = attachment.get("fileSize", 0)
|
||||
media_type = self._get_media_type_from_filename(file_name)
|
||||
|
||||
if file_size > DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD:
|
||||
return [], f"Attachment too large: {file_size} bytes"
|
||||
|
||||
try:
|
||||
raw_bytes = self._download_attachment(attachment_id)
|
||||
except Exception as e:
|
||||
return [], f"Failed to download attachment: {e}"
|
||||
|
||||
if media_type.startswith("image/"):
|
||||
if not self.allow_images:
|
||||
logger.info(
|
||||
f"Skipping image attachment {file_name} because allow_images is False",
|
||||
)
|
||||
return [], None
|
||||
|
||||
try:
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=raw_bytes,
|
||||
file_id=str(attachment_id),
|
||||
display_name=attachment.get(
|
||||
"name", attachment.get("fileName", "Unknown")
|
||||
),
|
||||
link=download_url,
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
sections.append(image_section)
|
||||
logger.debug(f"Stored image attachment with file name: {file_name}")
|
||||
except Exception as e:
|
||||
return [], f"Image storage failed: {e}"
|
||||
|
||||
return sections, None
|
||||
|
||||
image_counter = 0
|
||||
|
||||
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
|
||||
nonlocal image_counter
|
||||
|
||||
if not self.allow_images:
|
||||
return
|
||||
|
||||
media_for_image = self._get_media_type_from_filename(image_name)
|
||||
if media_for_image == "application/octet-stream":
|
||||
try:
|
||||
media_for_image = get_image_type_from_bytes(image_data)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Unable to determine media type for embedded image {image_name} on attachment {file_name}"
|
||||
)
|
||||
|
||||
image_counter += 1
|
||||
display_name = (
|
||||
image_name
|
||||
or f"{attachment.get('name', file_name)} - embedded image {image_counter}"
|
||||
)
|
||||
|
||||
try:
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=image_data,
|
||||
file_id=f"{attachment_id}_embedded_{image_counter}",
|
||||
display_name=display_name,
|
||||
link=download_url,
|
||||
media_type=media_for_image,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
sections.append(image_section)
|
||||
except Exception as err:
|
||||
logger.warning(
|
||||
f"Failed to store embedded image {image_name or image_counter} for attachment {file_name}: {err}"
|
||||
)
|
||||
|
||||
extraction_result = extract_text_and_images(
|
||||
file=BytesIO(raw_bytes),
|
||||
file_name=file_name,
|
||||
content_type=media_type,
|
||||
image_callback=_store_embedded_image if self.allow_images else None,
|
||||
)
|
||||
|
||||
text_content = extraction_result.text_content.strip()
|
||||
if text_content:
|
||||
sections.insert(0, TextSection(text=text_content, link=download_url))
|
||||
logger.info(
|
||||
f"Extracted {len(text_content)} characters from {file_name}"
|
||||
)
|
||||
elif not sections:
|
||||
return [], f"No text extracted for {file_name}"
|
||||
|
||||
return sections, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process attachment {attachment.get('name', 'unknown')} on page {page_id}: {e}"
|
||||
)
|
||||
return [], f"Failed to process attachment: {e}"
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""
|
||||
Load credentials for the Drupal Wiki connector.
|
||||
|
||||
Args:
|
||||
credentials: Dictionary containing the API token.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
api_token = credentials.get("drupal_wiki_api_token", "").strip()
|
||||
|
||||
if not api_token:
|
||||
raise ConnectorValidationError(
|
||||
"API token is required for Drupal Wiki connector"
|
||||
)
|
||||
|
||||
self._api_token = api_token
|
||||
self.headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
}
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_space_ids(self) -> list[int]:
|
||||
"""
|
||||
Get all space IDs from the Drupal Wiki instance.
|
||||
|
||||
Returns:
|
||||
List of space IDs (deduplicated). The list is sorted to be deterministic.
|
||||
"""
|
||||
url = f"{self.base_url}/api/rest/scope/api/space"
|
||||
size = MAX_API_PAGE_SIZE
|
||||
page = 0
|
||||
all_space_ids: set[int] = set()
|
||||
has_more = True
|
||||
last_num_ids = -1
|
||||
|
||||
while has_more and len(all_space_ids) > last_num_ids:
|
||||
last_num_ids = len(all_space_ids)
|
||||
params = {"size": size, "page": page}
|
||||
logger.debug(f"Fetching spaces from {url} (page={page}, size={size})")
|
||||
response = rate_limited_get(url, headers=self.headers, params=params)
|
||||
response.raise_for_status()
|
||||
resp_json = response.json()
|
||||
space_response = DrupalWikiSpaceResponse.model_validate(resp_json)
|
||||
|
||||
logger.info(f"Fetched {len(space_response.content)} spaces from {page}")
|
||||
# Collect ids into the set to deduplicate
|
||||
for space in space_response.content:
|
||||
all_space_ids.add(space.id)
|
||||
|
||||
# Continue if we got a full page, indicating there might be more
|
||||
has_more = len(space_response.content) >= size
|
||||
|
||||
page += 1
|
||||
|
||||
# Return a deterministic, sorted list of ids
|
||||
space_id_list = sorted(all_space_ids)
|
||||
logger.debug(f"Total spaces fetched: {len(space_id_list)}")
|
||||
return space_id_list
|
||||
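Reviewer note: the loop above stops when a page comes back short or when a pass adds no new IDs, which guards against an API that keeps returning the same page. A standalone sketch of that guard, with a hypothetical fetch_page stub:

def _paginate_ids(fetch_page, size):
    """Stub mirroring the stop conditions: a short page, or no new IDs since the last pass."""
    seen: set[int] = set()
    page = 0
    has_more = True
    last_count = -1
    while has_more and len(seen) > last_count:
        last_count = len(seen)
        batch = fetch_page(page)  # hypothetical fetcher returning a list of IDs
        seen.update(batch)
        has_more = len(batch) >= size  # a full page hints there may be more
        page += 1
    return sorted(seen)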
|
||||
def _get_pages_for_space(
|
||||
self, space_id: int, modified_after: SecondsSinceUnixEpoch | None = None
|
||||
) -> list[DrupalWikiPage]:
|
||||
"""
|
||||
Get all pages for a specific space, optionally filtered by modification time.
|
||||
|
||||
Args:
|
||||
space_id: ID of the space.
|
||||
modified_after: Only return pages modified after this timestamp (seconds since Unix epoch).
|
||||
|
||||
Returns:
|
||||
List of DrupalWikiPage objects.
|
||||
"""
|
||||
url = f"{self.base_url}/api/rest/scope/api/page"
|
||||
size = MAX_API_PAGE_SIZE
|
||||
page = 0
|
||||
all_pages = []
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
params: dict[str, str | int] = {
|
||||
DRUPAL_WIKI_SPACE_KEY: str(space_id),
|
||||
"size": size,
|
||||
"page": page,
|
||||
}
|
||||
|
||||
# Add modifiedAfter parameter if provided
|
||||
if modified_after is not None:
|
||||
params["modifiedAfter"] = int(modified_after)
|
||||
|
||||
logger.debug(
|
||||
f"Fetching pages for space {space_id} from {url} ({page=}, {size=}, {modified_after=})"
|
||||
)
|
||||
response = rate_limited_get(url, headers=self.headers, params=params)
|
||||
response.raise_for_status()
|
||||
resp_json = response.json()
|
||||
|
||||
try:
|
||||
page_response = DrupalWikiPageResponse.model_validate(resp_json)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate Drupal Wiki page response: {e}")
|
||||
raise ConnectorValidationError(f"Invalid API response format: {e}")
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(page_response.content)} pages in space {space_id} (page={page})"
|
||||
)
|
||||
|
||||
# Pydantic should automatically parse content items as DrupalWikiPage objects
|
||||
# If validation fails, it will raise an exception which we should catch
|
||||
all_pages.extend(page_response.content)
|
||||
|
||||
# Continue if we got a full page, indicating there might be more
|
||||
has_more = len(page_response.content) >= size
|
||||
|
||||
page += 1
|
||||
|
||||
logger.debug(f"Total pages fetched for space {space_id}: {len(all_pages)}")
|
||||
return all_pages
|
||||
|
||||
def _get_page_content(self, page_id: int) -> DrupalWikiPage:
|
||||
"""
|
||||
Get the content of a specific page.
|
||||
|
||||
Args:
|
||||
page_id: ID of the page.
|
||||
|
||||
Returns:
|
||||
DrupalWikiPage object.
|
||||
"""
|
||||
url = f"{self.base_url}/api/rest/scope/api/page/{page_id}"
|
||||
response = rate_limited_get(url, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return DrupalWikiPage.model_validate(response.json())
|
||||
|
||||
def _process_page(self, page: DrupalWikiPage) -> Document | ConnectorFailure:
|
||||
"""
|
||||
Process a page and convert it to a Document.
|
||||
|
||||
Args:
|
||||
page: DrupalWikiPage object.
|
||||
|
||||
Returns:
|
||||
Document object or ConnectorFailure.
|
||||
"""
|
||||
try:
|
||||
# Extract text from HTML, handle None body
|
||||
text_content = parse_html_page_basic(page.body or "")
|
||||
|
||||
# Ensure text_content is a string, not None
|
||||
if text_content is None:
|
||||
text_content = ""
|
||||
|
||||
# Create document URL
|
||||
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
|
||||
|
||||
# Create sections with just the page content
|
||||
sections: list[TextSection | ImageSection] = [
|
||||
TextSection(text=text_content, link=page_url)
|
||||
]
|
||||
|
||||
# Only process attachments if self.include_attachments is True
|
||||
if self.include_attachments:
|
||||
attachments = self._get_page_attachments(page.id)
|
||||
for attachment in attachments:
|
||||
logger.info(
|
||||
f"Processing attachment: {attachment.get('name', 'Unknown')} (ID: {attachment['id']})"
|
||||
)
|
||||
# Use downloadUrl from API; fallback to page URL
|
||||
raw_download = attachment.get("downloadUrl")
|
||||
if raw_download:
|
||||
download_url = (
|
||||
raw_download
|
||||
if raw_download.startswith("http")
|
||||
else f"{self.base_url.rstrip('/')}" + raw_download
|
||||
)
|
||||
else:
|
||||
download_url = page_url
|
||||
# Process the attachment
|
||||
attachment_sections, error = self._process_attachment(
|
||||
attachment, page.id, download_url
|
||||
)
|
||||
if error:
|
||||
logger.warning(
|
||||
f"Error processing attachment {attachment.get('name', 'Unknown')}: {error}"
|
||||
)
|
||||
continue
|
||||
|
||||
if attachment_sections:
|
||||
sections.extend(attachment_sections)
|
||||
logger.debug(
|
||||
f"Added {len(attachment_sections)} section(s) for attachment {attachment.get('name', 'Unknown')}"
|
||||
)
|
||||
|
||||
# Create metadata
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"space_id": str(page.homeSpace),
|
||||
"page_id": str(page.id),
|
||||
"type": page.type,
|
||||
}
|
||||
|
||||
# Create document
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=sections,
|
||||
source=DocumentSource.DRUPAL_WIKI,
|
||||
semantic_identifier=page.title,
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime_from_utc_timestamp(page.lastModified),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing page {page.id}: {e}")
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(page.id),
|
||||
document_link=build_drupal_wiki_document_id(self.base_url, page.id),
|
||||
),
|
||||
failure_message=f"Error processing page {page.id}: {e}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: DrupalWikiCheckpoint,
|
||||
) -> CheckpointOutput[DrupalWikiCheckpoint]:
|
||||
"""
|
||||
Load documents from a checkpoint.
|
||||
|
||||
Args:
|
||||
start: Start time as seconds since Unix epoch.
|
||||
end: End time as seconds since Unix epoch.
|
||||
checkpoint: Checkpoint to resume from.
|
||||
|
||||
Returns:
|
||||
Generator yielding documents and the updated checkpoint.
|
||||
"""
|
||||
# Ensure page_ids is not None
|
||||
if checkpoint.page_ids is None:
|
||||
checkpoint.page_ids = []
|
||||
|
||||
# Initialize page_ids from self.pages if not already set
|
||||
if not checkpoint.page_ids and self.pages:
|
||||
logger.info(f"Initializing page_ids from self.pages: {self.pages}")
|
||||
checkpoint.page_ids = [int(page_id.strip()) for page_id in self.pages]
|
||||
|
||||
# Ensure spaces is not None
|
||||
if checkpoint.spaces is None:
|
||||
checkpoint.spaces = []
|
||||
|
||||
while checkpoint.current_page_id_index < len(checkpoint.page_ids):
|
||||
page_id = checkpoint.page_ids[checkpoint.current_page_id_index]
|
||||
logger.debug(f"Processing page ID: {page_id}")
|
||||
|
||||
try:
|
||||
# Get the page content directly
|
||||
page = self._get_page_content(page_id)
|
||||
|
||||
# Skip pages outside the time range
|
||||
if not self._is_page_in_time_range(page.lastModified, start, end):
|
||||
logger.info(f"Skipping page {page_id} - outside time range")
|
||||
checkpoint.current_page_id_index += 1
|
||||
continue
|
||||
|
||||
# Process the page
|
||||
doc_or_failure = self._process_page(page)
|
||||
yield doc_or_failure
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing page ID {page_id}: {e}")
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(page_id),
|
||||
document_link=build_drupal_wiki_document_id(
|
||||
self.base_url, page_id
|
||||
),
|
||||
),
|
||||
failure_message=f"Error processing page ID {page_id}: {e}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
# Move to the next page ID
|
||||
checkpoint.current_page_id_index += 1
|
||||
|
||||
# TODO: The main benefit of CheckpointedConnectors is that they can "save their work"
|
||||
# by storing a checkpoint so transient errors are easy to recover from: simply resume
|
||||
# from the last checkpoint. The way to get checkpoints saved is to return them somewhere
|
||||
# in the middle of this function. The guarantee our checkpointing system gives to you,
|
||||
# the connector implementer, is that when you return a checkpoint, load_from_checkpoint
# will be called again at a later time (generally within a few seconds) with the
# checkpoint you last returned, as long as has_more=True.
|
||||
|
||||
# Process spaces if include_all_spaces is True or spaces are provided
|
||||
if self.include_all_spaces or self.spaces:
|
||||
# If include_all_spaces is True, always fetch all spaces
|
||||
if self.include_all_spaces:
|
||||
logger.info("Fetching all spaces")
|
||||
# Fetch all spaces
|
||||
all_space_ids = self._get_space_ids()
|
||||
# checkpoint.spaces expects a list of ints; assign returned list
|
||||
checkpoint.spaces = all_space_ids
|
||||
logger.info(f"Found {len(checkpoint.spaces)} spaces to process")
|
||||
# Otherwise, use provided spaces if checkpoint is empty
|
||||
elif not checkpoint.spaces:
|
||||
logger.info(f"Using provided spaces: {self.spaces}")
|
||||
# Use provided spaces
|
||||
checkpoint.spaces = [int(space_id.strip()) for space_id in self.spaces]
|
||||
|
||||
# Process spaces from the checkpoint
|
||||
while checkpoint.current_space_index < len(checkpoint.spaces):
|
||||
space_id = checkpoint.spaces[checkpoint.current_space_index]
|
||||
logger.debug(f"Processing space ID: {space_id}")
|
||||
|
||||
# Get pages for the current space, filtered by start time if provided
|
||||
pages = self._get_pages_for_space(space_id, modified_after=start)
|
||||
|
||||
# Process pages from the checkpoint
|
||||
while checkpoint.current_page_index < len(pages):
|
||||
page = pages[checkpoint.current_page_index]
|
||||
logger.debug(f"Processing page: {page.title} (ID: {page.id})")
|
||||
|
||||
# For space-based pages, we already filtered by modifiedAfter in the API call
|
||||
# Only need to check the end time boundary
|
||||
if end and page.lastModified >= end:
|
||||
logger.info(
|
||||
f"Skipping page {page.id} - outside time range (after end)"
|
||||
)
|
||||
checkpoint.current_page_index += 1
|
||||
continue
|
||||
|
||||
# Process the page
|
||||
doc_or_failure = self._process_page(page)
|
||||
yield doc_or_failure
|
||||
|
||||
# Move to the next page
|
||||
checkpoint.current_page_index += 1
|
||||
|
||||
# Move to the next space
|
||||
checkpoint.current_space_index += 1
|
||||
checkpoint.current_page_index = 0
|
||||
|
||||
# All spaces and pages processed
|
||||
logger.info("Finished processing all spaces and pages")
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> DrupalWikiCheckpoint:
|
||||
"""
|
||||
Build a dummy checkpoint.
|
||||
|
||||
Returns:
|
||||
DrupalWikiCheckpoint with default values.
|
||||
"""
|
||||
return DrupalWikiCheckpoint(
|
||||
has_more=True,
|
||||
current_space_index=0,
|
||||
current_page_index=0,
|
||||
current_page_id_index=0,
|
||||
spaces=[],
|
||||
page_ids=[],
|
||||
is_processing_specific_pages=False,
|
||||
)
|
||||
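Reviewer note: putting the checkpoint contract together, a hypothetical driver loop; the real orchestration lives in Onyx's indexing pipeline, and the URL, token, and time bounds here are placeholders.

connector = DrupalWikiConnector("https://wiki.example.com", include_all_spaces=True)
connector.load_credentials({"drupal_wiki_api_token": "<token>"})

checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(start=0, end=2_000_000_000, checkpoint=checkpoint)
    while True:
        try:
            doc_or_failure = next(gen)  # Document or ConnectorFailure; process it here
        except StopIteration as stop:
            checkpoint = stop.value  # the generator's return value is the updated checkpoint
            break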
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> DrupalWikiCheckpoint:
|
||||
"""
|
||||
Validate a checkpoint JSON string.
|
||||
|
||||
Args:
|
||||
checkpoint_json: JSON string representing a checkpoint.
|
||||
|
||||
Returns:
|
||||
Validated DrupalWikiCheckpoint.
|
||||
"""
|
||||
return DrupalWikiCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
# TODO: unify approach with load_from_checkpoint.
|
||||
# Ideally slim retrieval shares a lot of the same code with non-slim
|
||||
# and we pass in a param is_slim to the main helper function
|
||||
# that does the retrieval.
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Retrieve all slim documents.
|
||||
|
||||
Args:
|
||||
start: Start time as seconds since Unix epoch.
|
||||
end: End time as seconds since Unix epoch.
|
||||
callback: Callback for indexing heartbeat.
|
||||
|
||||
Returns:
|
||||
Generator yielding batches of SlimDocument objects.
|
||||
"""
|
||||
slim_docs: list[SlimDocument] = []
|
||||
logger.info(
|
||||
f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}"
|
||||
)
|
||||
|
||||
# Process specific page IDs if provided
|
||||
if self.pages:
|
||||
logger.info(f"Processing specific pages: {self.pages}")
|
||||
for page_id in self.pages:
|
||||
try:
|
||||
# Get the page content directly
|
||||
page_content = self._get_page_content(int(page_id.strip()))
|
||||
|
||||
# Skip pages outside the time range
|
||||
if not self._is_page_in_time_range(
|
||||
page_content.lastModified, start, end
|
||||
):
|
||||
logger.info(f"Skipping page {page_id} - outside time range")
|
||||
continue
|
||||
|
||||
# Create slim document for the page
|
||||
page_url = build_drupal_wiki_document_id(
|
||||
self.base_url, page_content.id
|
||||
)
|
||||
slim_docs.append(
|
||||
SlimDocument(
|
||||
id=page_url,
|
||||
)
|
||||
)
|
||||
logger.debug(f"Added slim document for page {page_content.id}")
|
||||
|
||||
# Process attachments for this page
|
||||
attachments = self._get_page_attachments(page_content.id)
|
||||
for attachment in attachments:
|
||||
if self._validate_attachment_filetype(attachment):
|
||||
attachment_url = f"{page_url}#attachment-{attachment['id']}"
|
||||
slim_docs.append(
|
||||
SlimDocument(
|
||||
id=attachment_url,
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"Added slim document for attachment {attachment['id']}"
|
||||
)
|
||||
|
||||
# Yield batch if it reaches the batch size
|
||||
if len(slim_docs) >= self.batch_size:
|
||||
logger.debug(
|
||||
f"Yielding batch of {len(slim_docs)} slim documents"
|
||||
)
|
||||
yield slim_docs
|
||||
slim_docs = []
|
||||
|
||||
if callback and callback.should_stop():
|
||||
return
|
||||
if callback:
|
||||
callback.progress("retrieve_all_slim_docs", 1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing page ID {page_id} for slim documents: {e}"
|
||||
)
|
||||
|
||||
# Process spaces if include_all_spaces is True or spaces are provided
|
||||
if self.include_all_spaces or self.spaces:
|
||||
logger.info("Processing spaces for slim documents")
|
||||
# Get spaces to process
|
||||
spaces_to_process = []
|
||||
if self.include_all_spaces:
|
||||
logger.info("Fetching all spaces for slim documents")
|
||||
# Fetch all spaces
|
||||
all_space_ids = self._get_space_ids()
|
||||
spaces_to_process = all_space_ids
|
||||
logger.info(f"Found {len(spaces_to_process)} spaces to process")
|
||||
else:
|
||||
logger.info(f"Using provided spaces: {self.spaces}")
|
||||
# Use provided spaces
|
||||
spaces_to_process = [int(space_id.strip()) for space_id in self.spaces]
|
||||
|
||||
# Process each space
|
||||
for space_id in spaces_to_process:
|
||||
logger.info(f"Processing space ID: {space_id}")
|
||||
# Get pages for the current space, filtered by start time if provided
|
||||
pages = self._get_pages_for_space(space_id, modified_after=start)
|
||||
|
||||
# Process each page
|
||||
for page in pages:
|
||||
logger.debug(f"Processing page: {page.title} (ID: {page.id})")
|
||||
# Skip pages outside the time range
|
||||
if end and page.lastModified >= end:
|
||||
logger.info(
|
||||
f"Skipping page {page.id} - outside time range (after end)"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create slim document for the page
|
||||
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
|
||||
slim_docs.append(
|
||||
SlimDocument(
|
||||
id=page_url,
|
||||
)
|
||||
)
|
||||
logger.info(f"Added slim document for page {page.id}")
|
||||
|
||||
# Process attachments for this page
|
||||
attachments = self._get_page_attachments(page.id)
|
||||
for attachment in attachments:
|
||||
if self._validate_attachment_filetype(attachment):
|
||||
attachment_url = f"{page_url}#attachment-{attachment['id']}"
|
||||
slim_docs.append(
|
||||
SlimDocument(
|
||||
id=attachment_url,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f"Added slim document for attachment {attachment['id']}"
|
||||
)
|
||||
|
||||
# Yield batch if it reaches the batch size
|
||||
if len(slim_docs) >= self.batch_size:
|
||||
logger.info(
|
||||
f"Yielding batch of {len(slim_docs)} slim documents"
|
||||
)
|
||||
yield slim_docs
|
||||
slim_docs = []
|
||||
|
||||
if callback and callback.should_stop():
|
||||
return
|
||||
if callback:
|
||||
callback.progress("retrieve_all_slim_docs", 1)
|
||||
|
||||
# Yield remaining documents
|
||||
if slim_docs:
|
||||
logger.debug(f"Yielding final batch of {len(slim_docs)} slim documents")
|
||||
yield slim_docs
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Validate the connector settings.
|
||||
|
||||
Raises:
|
||||
ConnectorValidationError: If the settings are invalid.
|
||||
"""
|
||||
if not self.headers:
|
||||
raise ConnectorMissingCredentialError("Drupal Wiki")
|
||||
|
||||
try:
|
||||
# Try to fetch spaces to validate the connection
|
||||
# Call the new helper which returns the list of space ids
|
||||
self._get_space_ids()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ConnectorValidationError(f"Failed to connect to Drupal Wiki: {e}")
|
||||
|
||||
def _is_page_in_time_range(
|
||||
self,
|
||||
last_modified: int,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a page's last modified timestamp falls within the specified time range.
|
||||
|
||||
Args:
|
||||
last_modified: The page's last modified timestamp.
|
||||
start: Start time as seconds since Unix epoch (inclusive).
|
||||
end: End time as seconds since Unix epoch (exclusive).
|
||||
|
||||
Returns:
|
||||
True if the page is within the time range, False otherwise.
|
||||
"""
|
||||
return (not start or last_modified >= start) and (
|
||||
not end or last_modified < end
|
||||
)
|
||||
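Reviewer note: the window semantics restated as a tiny standalone check, mirroring the method above; timestamps are arbitrary.

def _in_window(last_modified: int, start: int | None, end: int | None) -> bool:
    # Window is [start, end): start inclusive, end exclusive, either bound optional.
    return (not start or last_modified >= start) and (not end or last_modified < end)

assert _in_window(100, 100, 200)      # start is inclusive
assert not _in_window(200, 100, 200)  # end is exclusive
assert _in_window(5, None, None)      # no bounds means everything passes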
75
backend/onyx/connectors/drupal_wiki/models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from enum import Enum
|
||||
from typing import Generic
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
|
||||
|
||||
class SpaceAccessStatus(str, Enum):
|
||||
"""Enum for Drupal Wiki space access status"""
|
||||
|
||||
PRIVATE = "PRIVATE"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
AUTHENTICATED = "AUTHENTICATED"
|
||||
|
||||
|
||||
class DrupalWikiSpace(BaseModel):
|
||||
"""Model for a Drupal Wiki space"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
accessStatus: Optional[SpaceAccessStatus] = None
|
||||
color: Optional[str] = None
|
||||
|
||||
|
||||
class DrupalWikiPage(BaseModel):
|
||||
"""Model for a Drupal Wiki page"""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
homeSpace: int
|
||||
lastModified: int
|
||||
type: str
|
||||
body: Optional[str] = None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DrupalWikiBaseResponse(BaseModel, Generic[T]):
|
||||
"""Base model for Drupal Wiki API responses"""
|
||||
|
||||
totalPages: int
|
||||
totalElements: int
|
||||
size: int
|
||||
content: List[T]
|
||||
number: int
|
||||
first: bool
|
||||
last: bool
|
||||
numberOfElements: int
|
||||
empty: bool
|
||||
|
||||
|
||||
class DrupalWikiSpaceResponse(DrupalWikiBaseResponse[DrupalWikiSpace]):
|
||||
"""Model for the response from the Drupal Wiki spaces API"""
|
||||
|
||||
|
||||
class DrupalWikiPageResponse(DrupalWikiBaseResponse[DrupalWikiPage]):
|
||||
"""Model for the response from the Drupal Wiki pages API"""
|
||||
|
||||
|
||||
class DrupalWikiCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint for the Drupal Wiki connector"""
|
||||
|
||||
current_space_index: int = 0
|
||||
current_page_index: int = 0
|
||||
current_page_id_index: int = 0
|
||||
spaces: List[int] = []
|
||||
page_ids: List[int] = []
|
||||
is_processing_specific_pages: bool = False
|
||||
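Reviewer note: a minimal round-trip sketch of the checkpoint model, assuming the base ConnectorCheckpoint adds only has_more; the space IDs are placeholders.

cp = DrupalWikiCheckpoint(
    has_more=True,
    spaces=[12, 34],
    page_ids=[],
)
# JSON round-tripping is what validate_checkpoint_json in the connector relies on.
assert DrupalWikiCheckpoint.model_validate_json(cp.model_dump_json()) == cp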
10
backend/onyx/connectors/drupal_wiki/utils.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_drupal_wiki_document_id(base_url: str, page_id: int) -> str:
|
||||
"""Build a document ID for a Drupal Wiki page using the real URL format"""
|
||||
# Ensure base_url ends with a slash
|
||||
base_url = base_url.rstrip("/") + "/"
|
||||
return f"{base_url}node/{page_id}"
|
||||
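Reviewer note: a one-line sanity check of the URL format (the instance URL is a placeholder).

from onyx.connectors.drupal_wiki.utils import build_drupal_wiki_document_id

# Trailing slashes on the base URL are normalized before "node/<id>" is appended.
assert build_drupal_wiki_document_id("https://wiki.example.com/", 42) == "https://wiki.example.com/node/42"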
@@ -28,10 +28,8 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
@@ -70,14 +68,15 @@ def _process_egnyte_file(
|
||||
|
||||
file_name = file_metadata["name"]
|
||||
extension = get_file_ext(file_name)
|
||||
if not is_accepted_file_ext(
|
||||
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
):
|
||||
|
||||
# Explicitly excluding image extensions here. TODO: consider allowing images
|
||||
if extension not in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
return None
|
||||
|
||||
# Extract text content based on file type
|
||||
if is_text_file_extension(file_name):
|
||||
# TODO @wenxi-onyx: convert to extract_text_and_images
|
||||
if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
|
||||
encoding = detect_encoding(file_content)
|
||||
file_content_raw, file_metadata = read_text_file(
|
||||
file_content, encoding=encoding, ignore_onyx_metadata=False
|
||||
|
||||
@@ -18,8 +18,7 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -90,7 +89,7 @@ def _process_file(
|
||||
# Get file extension and determine file type
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
|
||||
if extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
logger.warning(
|
||||
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
|
||||
)
|
||||
@@ -111,7 +110,7 @@ def _process_file(
|
||||
title = metadata.get("title") or file_display_name
|
||||
|
||||
# 1) If the file itself is an image, handle that scenario quickly
|
||||
if extension in LoadConnector.IMAGE_EXTENSIONS:
|
||||
if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
|
||||
# Read the image data
|
||||
image_data = file.read()
|
||||
if not image_data:
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
|
||||
@@ -14,9 +14,9 @@ from typing import cast
|
||||
from typing import Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -1006,7 +1006,7 @@ class GoogleDriveConnector(
|
||||
file.user_email,
|
||||
)
|
||||
if file.error is None:
|
||||
file.error = exc # type: ignore[assignment]
|
||||
file.error = exc
|
||||
yield file
|
||||
continue
|
||||
|
||||
|
||||
@@ -29,14 +29,14 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import docx_to_text_and_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
from onyx.file_processing.extract_file_text import read_docx_file
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.file_types import OnyxMimeTypes
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
@@ -114,14 +114,6 @@ def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
|
||||
return urlunparse(parsed_url)
|
||||
|
||||
|
||||
def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
"""
|
||||
Return True if the mime_type is a common image type in GDrive.
|
||||
(e.g. 'image/png', 'image/jpeg')
|
||||
"""
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def download_request(
|
||||
service: GoogleDriveService, file_id: str, size_threshold: int
|
||||
) -> bytes:
|
||||
@@ -173,7 +165,7 @@ def _download_and_extract_sections_basic(
|
||||
def response_call() -> bytes:
|
||||
return download_request(service, file_id, size_threshold)
|
||||
|
||||
if is_gdrive_image_mime_type(mime_type):
|
||||
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
|
||||
# Skip images if not explicitly enabled
|
||||
if not allow_images:
|
||||
return []
|
||||
@@ -222,7 +214,7 @@ def _download_and_extract_sections_basic(
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
text, _ = read_docx_file(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
@@ -260,7 +252,7 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
# Final attempt at extracting text
|
||||
file_ext = get_file_ext(file.get("name", ""))
|
||||
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
|
||||
if file_ext not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
logger.warning(f"Skipping file {file.get('name')} due to extension.")
|
||||
return []
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
@@ -4,7 +4,7 @@ from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -179,7 +179,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
get_kv_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
)
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.discovery import build # type: ignore[import-untyped]
|
||||
from googleapiclient.discovery import Resource
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -23,9 +23,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -309,10 +308,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
|
||||
|
||||
elif (
|
||||
is_valid_format
|
||||
and (
|
||||
file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
|
||||
or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
|
||||
)
|
||||
and file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS
|
||||
and can_download
|
||||
):
|
||||
content_response = self.client.get_item_content(item_id)
|
||||
|
||||
@@ -27,8 +27,6 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
@@ -9,6 +10,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
from more_itertools import chunked
|
||||
from typing_extensions import override
|
||||
@@ -134,6 +136,80 @@ def _perform_jql_search(
|
||||
return _perform_jql_search_v2(jira_client, jql, start, max_results, fields)
|
||||
|
||||
|
||||
def _handle_jira_search_error(e: Exception, jql: str) -> None:
|
||||
"""Handle common Jira search errors and raise appropriate exceptions.
|
||||
|
||||
Args:
|
||||
e: The exception raised by the Jira API
|
||||
jql: The JQL query that caused the error
|
||||
|
||||
Raises:
|
||||
ConnectorValidationError: For HTTP 400 errors (invalid JQL or project)
|
||||
CredentialExpiredError: For HTTP 401 errors
|
||||
InsufficientPermissionsError: For HTTP 403 errors
|
||||
Exception: Re-raises the original exception for other error types
|
||||
"""
|
||||
# Extract error information from the exception
|
||||
error_text = ""
|
||||
status_code = None
|
||||
|
||||
def _format_error_text(error_payload: Any) -> str:
|
||||
error_messages = (
|
||||
error_payload.get("errorMessages", [])
|
||||
if isinstance(error_payload, dict)
|
||||
else []
|
||||
)
|
||||
if error_messages:
|
||||
return (
|
||||
"; ".join(error_messages)
|
||||
if isinstance(error_messages, list)
|
||||
else str(error_messages)
|
||||
)
|
||||
return str(error_payload)
|
||||
|
||||
# Try to get status code and error text from JIRAError or requests response
|
||||
if hasattr(e, "status_code"):
|
||||
status_code = e.status_code
|
||||
raw_text = getattr(e, "text", "")
|
||||
if isinstance(raw_text, str):
|
||||
try:
|
||||
error_text = _format_error_text(json.loads(raw_text))
|
||||
except Exception:
|
||||
error_text = raw_text
|
||||
else:
|
||||
error_text = str(raw_text)
|
||||
elif hasattr(e, "response") and e.response is not None:
|
||||
status_code = e.response.status_code
|
||||
# Try JSON first, fall back to text
|
||||
try:
|
||||
error_json = e.response.json()
|
||||
error_text = _format_error_text(error_json)
|
||||
except Exception:
|
||||
error_text = e.response.text
|
||||
|
||||
# Handle specific status codes
|
||||
if status_code == 400:
|
||||
if "does not exist for the field 'project'" in error_text:
|
||||
raise ConnectorValidationError(
|
||||
f"The specified Jira project does not exist or you don't have access to it. "
|
||||
f"JQL query: {jql}. Error: {error_text}"
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Invalid JQL query. JQL: {jql}. Error: {error_text}"
|
||||
)
|
||||
elif status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credentials are expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
f"Insufficient permissions to execute JQL query. JQL: {jql}"
|
||||
)
|
||||
|
||||
# Re-raise for other error types
|
||||
raise e
|
||||
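Reviewer note: a quick illustration of the status-code mapping using a hypothetical JIRAError-like stub; the real call sites wrap search_issues, as the later hunks in this diff show.

class _FakeJiraError(Exception):
    """Stand-in carrying only the attributes the handler inspects."""

    def __init__(self, status_code: int, text: str = "") -> None:
        self.status_code = status_code
        self.text = text

try:
    _handle_jira_search_error(_FakeJiraError(401), jql="project = ENG")
except CredentialExpiredError:
    pass  # 401 maps to CredentialExpiredError, per the docstring above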
|
||||
|
||||
def enhanced_search_ids(
|
||||
jira_client: JIRA, jql: str, nextPageToken: str | None = None
|
||||
) -> tuple[list[str], str | None]:
|
||||
@@ -149,8 +225,15 @@ def enhanced_search_ids(
|
||||
"nextPageToken": nextPageToken,
|
||||
"fields": "id",
|
||||
}
|
||||
response = jira_client._session.get(enhanced_search_path, params=params).json()
|
||||
return [str(issue["id"]) for issue in response["issues"]], response.get(
|
||||
try:
|
||||
response = jira_client._session.get(enhanced_search_path, params=params)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
_handle_jira_search_error(e, jql)
|
||||
raise # Explicitly re-raise for type checker, should never reach here
|
||||
|
||||
return [str(issue["id"]) for issue in response_json["issues"]], response_json.get(
|
||||
"nextPageToken"
|
||||
)
|
||||
|
||||
@@ -232,12 +315,16 @@ def _perform_jql_search_v2(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
try:
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
except JIRAError as e:
|
||||
_handle_jira_search_error(e, jql)
|
||||
raise # Explicitly re-raise for type checker, should never reach here
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
|
||||
@@ -10,7 +10,7 @@ from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from pywikibot import family # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
|
||||
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ from typing import cast
|
||||
from typing import ClassVar
|
||||
|
||||
import pywikibot.time # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators # type: ignore[import-untyped]
|
||||
from pywikibot import textlib # type: ignore[import-untyped]
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot import textlib
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
@@ -196,6 +196,10 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.highspot.connector",
|
||||
class_name="HighspotConnector",
|
||||
),
|
||||
DocumentSource.DRUPAL_WIKI: ConnectorMapping(
|
||||
module_path="onyx.connectors.drupal_wiki.connector",
|
||||
class_name="DrupalWikiConnector",
|
||||
),
|
||||
DocumentSource.IMAP: ConnectorMapping(
|
||||
module_path="onyx.connectors.imap.connector",
|
||||
class_name="ImapConnector",
|
||||
|
||||
@@ -55,12 +55,10 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.file_types import OnyxMimeTypes
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -328,7 +326,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
try:
|
||||
item_json = driveitem.to_json()
|
||||
mime_type = item_json.get("file", {}).get("mimeType")
|
||||
if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
|
||||
if not mime_type or mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
|
||||
# NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
|
||||
# for now, this skip must happen before we download the file
|
||||
# Similar to Google Drive, we'll just semi-silently skip excluded image types
|
||||
@@ -388,14 +386,14 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
return None
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
file_ext = driveitem.name.split(".")[-1]
|
||||
file_ext = get_file_ext(driveitem.name)
|
||||
|
||||
if not content_bytes:
|
||||
logger.warning(
|
||||
f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
|
||||
)
|
||||
elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
|
||||
# NOTE: this if should use is_valid_image_type instead with mime_type
|
||||
elif file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
|
||||
# NOTE: this if should probably check mime_type instead
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=content_bytes,
|
||||
file_id=driveitem.id,
|
||||
@@ -418,7 +416,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
|
||||
# The only mime type that would be returned by get_image_type_from_bytes that is in
|
||||
# EXCLUDED_IMAGE_TYPES is image/gif.
|
||||
if mime_type in EXCLUDED_IMAGE_TYPES:
|
||||
if mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
|
||||
logger.debug(
|
||||
"Skipping embedded image of excluded type %s for %s",
|
||||
mime_type,
|
||||
@@ -1506,7 +1504,7 @@ class SharepointConnector(
|
||||
)
|
||||
for driveitem in driveitems:
|
||||
driveitem_extension = get_file_ext(driveitem.name)
|
||||
if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
|
||||
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
logger.warning(
|
||||
f"Skipping {driveitem.web_url} as it is not a supported file type"
|
||||
)
|
||||
@@ -1514,7 +1512,7 @@ class SharepointConnector(
|
||||
|
||||
# Only yield empty documents if they are PDFs or images
|
||||
should_yield_if_empty = (
|
||||
driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
driveitem_extension in OnyxFileExtensions.IMAGE_EXTENSIONS
|
||||
or driveitem_extension == ".pdf"
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -50,7 +50,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
|
||||
class TeamsConnector(
|
||||
CheckpointedConnector[TeamsCheckpoint],
|
||||
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
MAX_WORKERS = 10
|
||||
@@ -247,13 +247,23 @@ class TeamsConnector(
|
||||
has_more=bool(todos),
|
||||
)
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: TeamsCheckpoint,
|
||||
) -> CheckpointOutput[TeamsCheckpoint]:
|
||||
# Teams already fetches external_access (permissions) for each document
|
||||
# in _convert_thread_to_document, so we can just delegate to load_from_checkpoint
|
||||
return self.load_from_checkpoint(start, end, checkpoint)
|
||||
|
||||
# impls for SlimConnectorWithPermSync
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
_end: SecondsSinceUnixEpoch | None = None,
|
||||
_callback: IndexingHeartbeatInterface | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
start = start or 0
|
||||
|
||||
@@ -302,6 +312,12 @@ class TeamsConnector(
|
||||
)
|
||||
|
||||
if len(slim_doc_buffer) >= _SLIM_DOC_BATCH_SIZE:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
|
||||
)
|
||||
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
|
||||
yield slim_doc_buffer
|
||||
slim_doc_buffer = []
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.teams.channels.channel import Channel # type: ignore
|
||||
from office365.teams.channels.channel import ConversationMember # type: ignore
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.teams.channels.channel import Channel # type: ignore[import-untyped]
|
||||
from office365.teams.channels.channel import ConversationMember
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
|
||||
@@ -18,6 +18,7 @@ from oauthlib.oauth2 import BackendApplicationClient
|
||||
from playwright.sync_api import BrowserContext
|
||||
from playwright.sync_api import Playwright
|
||||
from playwright.sync_api import sync_playwright
|
||||
from playwright.sync_api import TimeoutError
|
||||
from requests_oauthlib import OAuth2Session # type:ignore
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
|
||||
@@ -86,6 +87,8 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
IFRAME_TEXT_LENGTH_THRESHOLD = 700
|
||||
# Message indicating JavaScript is disabled, which often appears when scraping fails
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
# Grace period after page navigation to allow bot-detection challenges to complete
|
||||
BOT_DETECTION_GRACE_PERIOD_MS = 5000
|
||||
|
||||
# Define common headers that mimic a real browser
|
||||
DEFAULT_USER_AGENT = (
|
||||
@@ -554,12 +557,17 @@ class WebConnector(LoadConnector):
|
||||
|
||||
page = session_ctx.playwright_context.new_page()
|
||||
try:
|
||||
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
|
||||
# Use "commit" instead of "domcontentloaded" to avoid hanging on bot-detection pages
|
||||
# that may never fire domcontentloaded. "commit" waits only for navigation to be
|
||||
# committed (response received), then we add a short wait for initial rendering.
|
||||
page_response = page.goto(
|
||||
initial_url,
|
||||
timeout=30000, # 30 seconds
|
||||
wait_until="domcontentloaded", # Wait for DOM to be ready
|
||||
wait_until="commit", # Wait for navigation to commit
|
||||
)
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified") if page_response else None
|
||||
@@ -584,8 +592,15 @@ class WebConnector(LoadConnector):
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
# wait for the content to load if we scrolled
|
||||
page.wait_for_load_state("networkidle", timeout=30000)
|
||||
# Wait for content to load, but catch timeout if page never reaches networkidle
|
||||
# (e.g., CloudFlare protection keeps making requests)
|
||||
try:
|
||||
page.wait_for_load_state(
|
||||
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
|
||||
)
|
||||
except TimeoutError:
|
||||
# If networkidle times out, just give it a moment for content to render
|
||||
time.sleep(1)
|
||||
time.sleep(0.5) # let javascript run
|
||||
|
||||
new_height = page.evaluate("document.body.scrollHeight")
|
||||
|
||||
@@ -21,6 +21,13 @@ class OptionalSearchSetting(str, Enum):
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""
|
||||
The type of first-pass query to use for hybrid search.
|
||||
|
||||
The values of this enum are injected into the ranking profile name which
|
||||
should match the name in the schema.
|
||||
"""
|
||||
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from slack_sdk.errors import SlackApiError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_SLACK_THREAD_CONTEXT_MESSAGES
|
||||
from onyx.configs.app_configs import SLACK_THREAD_CONTEXT_BATCH_SIZE
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -39,7 +41,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -623,33 +625,55 @@ def merge_slack_messages(
|
||||
return merged_messages, docid_to_message, all_filtered_channels
|
||||
|
||||
|
||||
def get_contextualized_thread_text(
|
||||
class SlackRateLimitError(Exception):
|
||||
"""Raised when Slack API returns a rate limit error (429)."""
|
||||
|
||||
|
||||
class ThreadContextResult:
|
||||
"""Result wrapper for thread context fetch that captures error type."""
|
||||
|
||||
__slots__ = ("text", "is_rate_limited", "is_error")
|
||||
|
||||
def __init__(
|
||||
self, text: str, is_rate_limited: bool = False, is_error: bool = False
|
||||
):
|
||||
self.text = text
|
||||
self.is_rate_limited = is_rate_limited
|
||||
self.is_error = is_error
|
||||
|
||||
@classmethod
|
||||
def success(cls, text: str) -> "ThreadContextResult":
|
||||
return cls(text)
|
||||
|
||||
@classmethod
|
||||
def rate_limited(cls, original_text: str) -> "ThreadContextResult":
|
||||
return cls(original_text, is_rate_limited=True)
|
||||
|
||||
@classmethod
|
||||
def error(cls, original_text: str) -> "ThreadContextResult":
|
||||
return cls(original_text, is_error=True)
|
||||
|
||||
|
||||
def _fetch_thread_context(
|
||||
message: SlackMessage, access_token: str, team_id: str | None = None
|
||||
) -> str:
|
||||
) -> ThreadContextResult:
|
||||
"""
|
||||
Retrieves the initial thread message as well as the text following the message
|
||||
and combines them into a single string. If the slack query fails, returns the
|
||||
original message text.
|
||||
Fetch thread context for a message, returning a result object.
|
||||
|
||||
The idea is that the message (the one that actually matched the search), the
|
||||
initial thread message, and the replies to the message are important in answering
|
||||
the user's query.
|
||||
|
||||
Args:
|
||||
message: The SlackMessage to get context for
|
||||
access_token: Slack OAuth access token
|
||||
team_id: Slack team ID for caching user profiles (optional but recommended)
|
||||
Returns ThreadContextResult with:
|
||||
- success: enriched thread text
|
||||
- rate_limited: original text + flag indicating we should stop
|
||||
- error: original text for other failures (graceful degradation)
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# if it's not a thread, return the message text
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
return message.text
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# get the thread messages
|
||||
slack_client = WebClient(token=access_token)
|
||||
slack_client = WebClient(token=access_token, timeout=30)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
@@ -658,19 +682,44 @@ def get_contextualized_thread_text(
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
|
||||
return message.text
|
||||
# Check for rate limit error specifically
|
||||
if e.response and e.response.status_code == 429:
|
||||
logger.warning(
|
||||
f"Slack rate limit hit while fetching thread context for {channel_id}/{thread_id}"
|
||||
)
|
||||
return ThreadContextResult.rate_limited(message.text)
|
||||
# For other Slack errors, log and return original text
|
||||
logger.error(f"Slack API error in thread context fetch: {e}")
|
||||
return ThreadContextResult.error(message.text)
|
||||
except Exception as e:
|
||||
# Network errors, timeouts, etc - treat as recoverable error
|
||||
logger.error(f"Unexpected error in thread context fetch: {e}")
|
||||
return ThreadContextResult.error(message.text)
|
||||
|
||||
# make sure we didn't get an empty response or a single message (not a thread)
|
||||
# If empty response or single message (not a thread), return original text
|
||||
if len(messages) <= 1:
|
||||
return message.text
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# add the initial thread message
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# add the message (unless it's the initial message)
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
@@ -681,28 +730,21 @@ def get_contextualized_thread_text(
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
# Include a few messages BEFORE the matched message for context
|
||||
# This helps understand what the matched message is responding to
|
||||
start_idx = max(
|
||||
1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW
|
||||
) # Start after thread starter
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
# Add ellipsis if we're skipping messages between thread starter and context window
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
# Add context messages before the matched message
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add the matched message itself
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# add the following replies to the thread text
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
@@ -710,22 +752,19 @@ def get_contextualized_thread_text(
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# replace user ids with names in the thread text using cached lookups
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
|
||||
if team_id:
|
||||
# Use cached batch lookup when team_id is available
|
||||
user_profiles = batch_get_user_profiles(access_token, team_id, userids)
|
||||
for userid, name in user_profiles.items():
|
||||
thread_text = thread_text.replace(f"<@{userid}>", name)
|
||||
else:
|
||||
# Fallback to individual lookups (no caching) when team_id not available
|
||||
for userid in userids:
|
||||
try:
|
||||
response = slack_client.users_profile_get(user=userid)
|
||||
@@ -735,7 +774,7 @@ def get_contextualized_thread_text(
|
||||
except SlackApiError as e:
|
||||
if "user_not_found" in str(e):
|
||||
logger.debug(
|
||||
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
|
||||
f"User {userid} not found (likely deleted/deactivated)"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Could not fetch profile for user {userid}: {e}")
|
||||
@@ -747,6 +786,115 @@ def get_contextualized_thread_text(
|
||||
return thread_text
|
||||
|
||||
|
||||
def fetch_thread_contexts_with_rate_limit_handling(
|
||||
slack_messages: list[SlackMessage],
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
batch_size: int = SLACK_THREAD_CONTEXT_BATCH_SIZE,
|
||||
max_messages: int | None = MAX_SLACK_THREAD_CONTEXT_MESSAGES,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Fetch thread contexts in controlled batches, stopping on rate limit.
|
||||
|
||||
Distinguishes between error types:
|
||||
- Rate limit (429): Stop processing further batches
|
||||
- Other errors: Continue processing (graceful degradation)
|
||||
|
||||
Args:
|
||||
slack_messages: Messages to fetch thread context for (should be sorted by relevance)
|
||||
access_token: Slack OAuth token
|
||||
team_id: Slack team ID for user profile caching
|
||||
batch_size: Number of concurrent API calls per batch
|
||||
max_messages: Maximum messages to fetch thread context for (None = no limit)
|
||||
|
||||
Returns:
|
||||
List of thread texts, one per input message.
|
||||
Messages beyond max_messages or after rate limit get their original text.
|
||||
"""
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
# Limit how many messages we fetch thread context for (if max_messages is set)
|
||||
if max_messages and max_messages < len(slack_messages):
|
||||
messages_for_context = slack_messages[:max_messages]
|
||||
messages_without_context = slack_messages[max_messages:]
|
||||
else:
|
||||
messages_for_context = slack_messages
|
||||
messages_without_context = []
|
||||
|
||||
logger.info(
|
||||
f"Fetching thread context for {len(messages_for_context)} of {len(slack_messages)} messages "
|
||||
f"(batch_size={batch_size}, max={max_messages or 'unlimited'})"
|
||||
)
|
||||
|
||||
results: list[str] = []
|
||||
rate_limited = False
|
||||
total_batches = (len(messages_for_context) + batch_size - 1) // batch_size
|
||||
rate_limit_batch = 0
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(messages_for_context), batch_size):
|
||||
current_batch = i // batch_size + 1
|
||||
|
||||
if rate_limited:
|
||||
# Skip remaining batches, use original message text
|
||||
remaining = messages_for_context[i:]
|
||||
skipped_batches = total_batches - rate_limit_batch
|
||||
logger.warning(
|
||||
f"Slack rate limit: skipping {len(remaining)} remaining messages "
|
||||
f"({skipped_batches} of {total_batches} batches). "
|
||||
f"Successfully enriched {len(results)} messages before rate limit."
|
||||
)
|
||||
results.extend([msg.text for msg in remaining])
|
||||
break
|
||||
|
||||
batch = messages_for_context[i : i + batch_size]
|
||||
|
||||
# _fetch_thread_context returns ThreadContextResult (never raises)
|
||||
# allow_failures=True is a safety net for any unexpected exceptions
|
||||
batch_results: list[ThreadContextResult | None] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
_fetch_thread_context,
|
||||
(msg, access_token, team_id),
|
||||
)
|
||||
for msg in batch
|
||||
],
|
||||
allow_failures=True,
|
||||
max_workers=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Process results - ThreadContextResult tells us exactly what happened
|
||||
for j, result in enumerate(batch_results):
|
||||
if result is None:
|
||||
# Unexpected exception (shouldn't happen) - use original text, stop
|
||||
logger.error(f"Unexpected None result for message {j} in batch")
|
||||
results.append(batch[j].text)
|
||||
rate_limited = True
|
||||
rate_limit_batch = current_batch
|
||||
elif result.is_rate_limited:
|
||||
# Rate limit hit - use original text, stop further batches
|
||||
results.append(result.text)
|
||||
rate_limited = True
|
||||
rate_limit_batch = current_batch
|
||||
else:
|
||||
# Success or recoverable error - use the text (enriched or original)
|
||||
results.append(result.text)
|
||||
|
||||
if rate_limited:
|
||||
logger.warning(
|
||||
f"Slack rate limit (429) hit at batch {current_batch}/{total_batches} "
|
||||
f"while fetching thread context. Stopping further API calls."
|
||||
)
|
||||
|
||||
# Add original text for messages we didn't fetch context for
|
||||
results.extend([msg.text for msg in messages_without_context])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def convert_slack_score(slack_score: float) -> float:
|
||||
"""
|
||||
Convert slack score to a score between 0 and 1.
|
||||
@@ -830,8 +978,8 @@ def slack_retrieval(
|
||||
)
|
||||
|
||||
# Query slack with entity filtering
|
||||
_, fast_llm = get_default_llms()
|
||||
query_strings = build_slack_queries(query, fast_llm, entities, available_channels)
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -964,11 +1112,12 @@ def slack_retrieval(
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
thread_texts: list[str] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(get_contextualized_thread_text, (slack_message, access_token, team_id))
|
||||
for slack_message in slack_messages
|
||||
]
|
||||
# Fetch thread context with rate limit handling and message limiting
|
||||
# Messages are already sorted by relevance (slack_score), so top N get full context
|
||||
thread_texts = fetch_thread_contexts_with_rate_limit_handling(
|
||||
slack_messages=slack_messages,
|
||||
access_token=access_token,
|
||||
team_id=team_id,
|
||||
)
|
||||
for slack_message, thread_text in zip(slack_messages, thread_texts):
|
||||
slack_message.text = thread_text
|
||||
|
||||
@@ -90,6 +90,16 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
@@ -104,6 +114,7 @@ def _build_index_filters(
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
return final_filters
|
||||
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ def query_analysis(query: str) -> tuple[bool, list[str]]:
|
||||
return analysis_model.predict(query)
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
|
||||
@@ -118,6 +118,7 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
@@ -348,6 +349,7 @@ def retrieve_chunks(
|
||||
list(query.filters.source_type) if query.filters.source_type else None,
|
||||
query.filters.document_set,
|
||||
slack_context,
|
||||
query.filters.user_file_ids,
|
||||
)
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -475,6 +477,7 @@ def search_chunks(
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
slack_context=slack_context,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
@@ -510,6 +513,7 @@ def search_chunks(
|
||||
return top_chunks
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
def inference_sections_from_ids(
|
||||
doc_identifiers: list[tuple[str, int]],
|
||||
document_index: DocumentIndex,
|
||||
|
||||
@@ -63,7 +63,7 @@ def get_live_users_count(db_session: Session) -> int:
|
||||
This does NOT include invited users, "users" pulled in
|
||||
from external connectors, or API keys.
|
||||
"""
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
count_stmt = func.count(User.id)
|
||||
select_stmt = select(count_stmt)
|
||||
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
|
||||
user_count = db_session.scalar(select_stmt_w_filters)
|
||||
@@ -74,7 +74,7 @@ def get_live_users_count(db_session: Session) -> int:
|
||||
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_context_manager() as session:
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
count_stmt = func.count(User.id)
|
||||
stmt = select(count_stmt)
|
||||
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
|
||||
user_count = await session.scalar(stmt_w_filters)
|
||||
@@ -100,10 +100,10 @@ class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def get_user_db(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount)
|
||||
|
||||
|
||||
async def get_access_token_db(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
|
||||
yield SQLAlchemyAccessTokenDatabase(session, AccessToken) # type: ignore
|
||||
yield SQLAlchemyAccessTokenDatabase(session, AccessToken)
|
||||
|
||||
@@ -626,7 +626,7 @@ def reserve_message_id(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message,
|
||||
latest_child_message_id=None,
|
||||
message="Response was termination prior to completion, try regenerating.",
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15,
|
||||
message_type=message_type,
|
||||
)
|
||||
@@ -744,29 +744,61 @@ def update_search_docs_table_with_relevance(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
if value and not sanitized:
|
||||
logger.warning("Sanitization removed all characters from string")
|
||||
return sanitized
|
||||
|
||||
|
||||
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
|
||||
"""Remove NUL (0x00) characters from all strings in a list."""
|
||||
return [_sanitize_for_postgres(v) for v in values]
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> DBSearchDoc:
|
||||
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
|
||||
db_search_doc = DBSearchDoc(
|
||||
document_id=server_search_doc.document_id,
|
||||
document_id=_sanitize_for_postgres(server_search_doc.document_id),
|
||||
chunk_ind=server_search_doc.chunk_ind,
|
||||
semantic_id=server_search_doc.semantic_identifier or "Unknown",
|
||||
link=server_search_doc.link,
|
||||
blurb=server_search_doc.blurb,
|
||||
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
|
||||
link=(
|
||||
_sanitize_for_postgres(server_search_doc.link)
|
||||
if server_search_doc.link is not None
|
||||
else None
|
||||
),
|
||||
blurb=_sanitize_for_postgres(server_search_doc.blurb),
|
||||
source_type=server_search_doc.source_type,
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=server_search_doc.relevance_explanation,
|
||||
relevance_explanation=(
|
||||
_sanitize_for_postgres(server_search_doc.relevance_explanation)
|
||||
if server_search_doc.relevance_explanation is not None
|
||||
else None
|
||||
),
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=server_search_doc.match_highlights,
|
||||
match_highlights=_sanitize_list_for_postgres(
|
||||
server_search_doc.match_highlights
|
||||
),
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=server_search_doc.primary_owners,
|
||||
secondary_owners=server_search_doc.secondary_owners,
|
||||
primary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.primary_owners)
|
||||
if server_search_doc.primary_owners is not None
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
|
||||
if server_search_doc.secondary_owners is not None
|
||||
else None
|
||||
),
|
||||
is_internet=server_search_doc.is_internet,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,21 @@ def check_connectors_exist(db_session: Session) -> bool:
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def check_user_files_exist(db_session: Session) -> bool:
|
||||
"""Check if any user files exist in the system.
|
||||
|
||||
This is used to determine if the search tool should be available
|
||||
when there are no regular connectors but there are user files
|
||||
(User Knowledge mode).
|
||||
"""
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
stmt = select(exists(UserFile).where(UserFile.status == UserFileStatus.COMPLETED))
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def fetch_connectors(
|
||||
db_session: Session,
|
||||
sources: list[DocumentSource] | None = None,
|
||||
|
||||
@@ -290,7 +290,7 @@ def get_document_counts_for_cc_pairs(
|
||||
)
|
||||
)
|
||||
|
||||
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
|
||||
for connector_id, credential_id, cnt in db_session.execute(stmt).all():
|
||||
aggregated_counts[(connector_id, credential_id)] = cnt
|
||||
|
||||
# Convert aggregated results back to the expected sequence of tuples
|
||||
@@ -1098,7 +1098,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
|
||||
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def update_document_kg_stages(
|
||||
@@ -1121,7 +1121,7 @@ def update_document_kg_stages(
|
||||
result = db_session.execute(stmt)
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def get_skipped_kg_documents(db_session: Session) -> list[str]:
|
||||
|
||||
@@ -234,9 +234,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.fast_default_model_name = (
|
||||
llm_provider_upsert_request.fast_default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
@@ -24,7 +26,9 @@ logger = setup_logger()
|
||||
# MCPServer operations
|
||||
def get_all_mcp_servers(db_session: Session) -> list[MCPServer]:
|
||||
"""Get all MCP servers"""
|
||||
return list(db_session.scalars(select(MCPServer)).all())
|
||||
return list(
|
||||
db_session.scalars(select(MCPServer).order_by(MCPServer.created_at)).all()
|
||||
)
|
||||
|
||||
|
||||
def get_mcp_server_by_id(server_id: int, db_session: Session) -> MCPServer:
|
||||
@@ -124,6 +128,7 @@ def update_mcp_server__no_commit(
|
||||
auth_performer: MCPAuthenticationPerformer | None = None,
|
||||
transport: MCPTransport | None = None,
|
||||
status: MCPServerStatus | None = None,
|
||||
last_refreshed_at: datetime.datetime | None = None,
|
||||
) -> MCPServer:
|
||||
"""Update an existing MCP server"""
|
||||
server = get_mcp_server_by_id(server_id, db_session)
|
||||
@@ -144,6 +149,8 @@ def update_mcp_server__no_commit(
|
||||
server.transport = transport
|
||||
if status is not None:
|
||||
server.status = status
|
||||
if last_refreshed_at is not None:
|
||||
server.last_refreshed_at = last_refreshed_at
|
||||
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
return server
|
||||
@@ -330,3 +337,15 @@ def delete_user_connection_configs_for_server(
|
||||
db_session.delete(config)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_all_user_connection_configs_for_server_no_commit(
|
||||
server_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Delete all user connection configs for a specific MCP server"""
|
||||
db_session.execute(
|
||||
delete(MCPConnectionConfig).where(
|
||||
MCPConnectionConfig.mcp_server_id == server_id
|
||||
)
|
||||
)
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.models import Milestone
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
USER_ASSISTANT_PREFIX = "user_assistants_used_"
|
||||
MULTI_ASSISTANT_USED = "multi_assistant_used"
|
||||
|
||||
|
||||
def create_milestone(
|
||||
user: User | None,
|
||||
event_type: MilestoneRecordType,
|
||||
db_session: Session,
|
||||
) -> Milestone:
|
||||
milestone = Milestone(
|
||||
event_type=event_type,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(milestone)
|
||||
db_session.commit()
|
||||
|
||||
return milestone
|
||||
|
||||
|
||||
def create_milestone_if_not_exists(
|
||||
user: User | None, event_type: MilestoneRecordType, db_session: Session
|
||||
) -> tuple[Milestone, bool]:
|
||||
# Check if it exists
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if milestone is not None:
|
||||
return milestone, False
|
||||
|
||||
# If it doesn't exist, try to create it.
|
||||
try:
|
||||
milestone = create_milestone(user, event_type, db_session)
|
||||
return milestone, True
|
||||
except IntegrityError:
|
||||
# Another thread or process inserted it in the meantime
|
||||
db_session.rollback()
|
||||
# Fetch again to return the existing record
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one() # Now should exist
|
||||
return milestone, False
|
||||
|
||||
|
||||
def update_user_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
user_id: str | None,
|
||||
assistant_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
milestone.event_tracker = event_tracker = {}
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
# No need to keep tracking and populating if the milestone has already been hit
|
||||
return
|
||||
|
||||
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
|
||||
|
||||
if event_tracker.get(user_key) is None:
|
||||
event_tracker[user_key] = [assistant_id]
|
||||
elif assistant_id not in event_tracker[user_key]:
|
||||
event_tracker[user_key].append(assistant_id)
|
||||
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_multi_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
db_session: Session,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Returns if the milestone was hit and if it was just hit for the first time"""
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
return False, False
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
return True, False
|
||||
|
||||
for key, value in event_tracker.items():
|
||||
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
|
||||
event_tracker[MULTI_ASSISTANT_USED] = True
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
return True, True
|
||||
|
||||
return False, False
|
||||
@@ -2215,6 +2215,8 @@ class ToolCall(Base):
|
||||
# The tools with the same turn number (and parent) were called in parallel
|
||||
# Ones with different turn numbers (and same parent) were called sequentially
|
||||
turn_number: Mapped[int] = mapped_column(Integer)
|
||||
# Index order of tool calls from the LLM for parallel tool calls
|
||||
tab_index: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# Not a FK because we want to be able to delete the tool without deleting
|
||||
# this entry
|
||||
@@ -2382,7 +2384,6 @@ class LLMProvider(Base):
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
@@ -2958,7 +2959,7 @@ class SlackChannelConfig(Base):
|
||||
"slack_bot_id",
|
||||
"is_default",
|
||||
unique=True,
|
||||
postgresql_where=(is_default is True), # type: ignore
|
||||
postgresql_where=(is_default is True),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3673,6 +3674,9 @@ class MCPServer(Base):
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
last_refreshed_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
admin_connection_config: Mapped["MCPConnectionConfig | None"] = relationship(
|
||||
@@ -3685,6 +3689,7 @@ class MCPServer(Base):
|
||||
"MCPConnectionConfig",
|
||||
foreign_keys="MCPConnectionConfig.mcp_server_id",
|
||||
back_populates="mcp_server",
|
||||
passive_deletes=True,
|
||||
)
|
||||
current_actions: Mapped[list["Tool"]] = relationship(
|
||||
"Tool", back_populates="mcp_server", cascade="all, delete-orphan"
|
||||
@@ -3913,3 +3918,22 @@ class ExternalGroupPermissionSyncAttempt(Base):
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status.is_terminal()
|
||||
|
||||
|
||||
class License(Base):
|
||||
"""Stores the signed license blob (singleton pattern - only one row)."""
|
||||
|
||||
__tablename__ = "license"
|
||||
__table_args__ = (
|
||||
# Singleton pattern - unique index on constant ensures only one row
|
||||
Index("idx_license_singleton", text("(true)"), unique=True),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
license_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
@@ -221,6 +221,7 @@ def create_tool_call_no_commit(
|
||||
parent_tool_call_id: int | None = None,
|
||||
reasoning_tokens: str | None = None,
|
||||
generated_images: list[dict] | None = None,
|
||||
tab_index: int = 0,
|
||||
add_only: bool = True,
|
||||
) -> ToolCall:
|
||||
"""
|
||||
@@ -239,6 +240,7 @@ def create_tool_call_no_commit(
|
||||
parent_tool_call_id: Optional parent tool call ID (for nested tool calls)
|
||||
reasoning_tokens: Optional reasoning tokens
|
||||
generated_images: Optional list of generated image metadata for replay
|
||||
tab_index: Index order of tool calls from the LLM for parallel tool calls
|
||||
commit: If True, commit the transaction; if False, flush only
|
||||
|
||||
Returns:
|
||||
@@ -249,6 +251,7 @@ def create_tool_call_no_commit(
|
||||
parent_chat_message_id=parent_chat_message_id,
|
||||
parent_tool_call_id=parent_tool_call_id,
|
||||
turn_number=turn_number,
|
||||
tab_index=tab_index,
|
||||
tool_id=tool_id,
|
||||
tool_call_id=tool_call_id,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
|
||||
@@ -257,7 +257,7 @@ def _get_users_by_emails(
|
||||
"""given a list of lowercase emails,
|
||||
returns a list[User] of Users whose emails match and a list[str]
|
||||
the missing emails that had no User"""
|
||||
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
|
||||
stmt = select(User).filter(func.lower(User.email).in_(lower_emails))
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
|
||||
# Extract found emails and convert to lowercase to avoid case sensitivity issues
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user