mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
3 Commits
thread_sen
...
eric/tool_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc198ce1fb | ||
|
|
52222669b7 | ||
|
|
8fa2f2953b |
@@ -1,8 +0,0 @@
|
||||
# Exclude these commits from git blame (e.g. mass reformatting).
|
||||
# These are ignored by GitHub automatically.
|
||||
# To enable this locally, run:
|
||||
#
|
||||
# git config blame.ignoreRevsFile .git-blame-ignore-revs
|
||||
|
||||
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
|
||||
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
|
||||
7
.github/CODEOWNERS
vendored
7
.github/CODEOWNERS
vendored
@@ -1,10 +1,3 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
# Helm charts Owners
|
||||
/helm/ @justin-tahara
|
||||
|
||||
# Web standards updates
|
||||
/web/STANDARDS.md @raunakab @Weves
|
||||
|
||||
# Agent context files
|
||||
/CLAUDE.md.template @Weves
|
||||
/AGENTS.md.template @Weves
|
||||
|
||||
@@ -7,6 +7,12 @@ inputs:
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Compute requirements hash
|
||||
id: req-hash
|
||||
shell: bash
|
||||
@@ -22,8 +28,6 @@ runs:
|
||||
done <<< "$REQUIREMENTS"
|
||||
echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# NOTE: This comes before Setup uv since clean-ups run in reverse chronological order
|
||||
# such that Setup uv's prune-cache is able to prune the cache before we upload.
|
||||
- name: Cache uv cache directory
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
@@ -32,14 +36,6 @@ runs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-uv-
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -1,10 +1,10 @@
|
||||
## Description
|
||||
|
||||
<!--- Provide a brief description of the changes in this PR --->
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
<!--- Describe the tests you ran to verify your changes --->
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
## Additional Options
|
||||
|
||||
|
||||
33
.github/workflows/check-lazy-imports.yml
vendored
Normal file
33
.github/workflows/check-lazy-imports.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Check Lazy Imports
|
||||
concurrency:
|
||||
group: Check-Lazy-Imports-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
check-lazy-imports:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Check lazy imports
|
||||
run: python3 backend/scripts/check_lazy_imports.py
|
||||
343
.github/workflows/deployment.yml
vendored
343
.github/workflows/deployment.yml
vendored
@@ -6,11 +6,11 @@ on:
|
||||
- "*"
|
||||
workflow_dispatch:
|
||||
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
@@ -20,7 +20,6 @@ jobs:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
outputs:
|
||||
build-desktop: ${{ steps.check.outputs.build-desktop }}
|
||||
build-web: ${{ steps.check.outputs.build-web }}
|
||||
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
|
||||
build-backend: ${{ steps.check.outputs.build-backend }}
|
||||
@@ -30,46 +29,32 @@ jobs:
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
|
||||
# Initialize all flags to false
|
||||
IS_CLOUD=false
|
||||
IS_NIGHTLY=false
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
BUILD_MODEL_SERVER=true
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
fi
|
||||
if [[ "$TAG" == nightly* ]]; then
|
||||
IS_NIGHTLY=true
|
||||
fi
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then
|
||||
IS_VERSION_TAG=true
|
||||
fi
|
||||
|
||||
# Version checks (for web - any stable version)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
IS_STABLE=true
|
||||
fi
|
||||
@@ -77,37 +62,15 @@ jobs:
|
||||
IS_BETA=true
|
||||
fi
|
||||
|
||||
# Determine what to build based on tag type
|
||||
if [[ "$IS_CLOUD" == "true" ]]; then
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Skip desktop builds on beta tags and nightly runs
|
||||
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
# Version checks (for backend/model-server - stable version excluding cloud tags)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
IS_PROD_TAG=true
|
||||
fi
|
||||
|
||||
# Determine if this is a test run (workflow_dispatch on non-production ref)
|
||||
if [[ "$EVENT_NAME" == "workflow_dispatch" ]] && [[ "$IS_PROD_TAG" != "true" ]]; then
|
||||
IS_TEST_RUN=true
|
||||
fi
|
||||
{
|
||||
echo "build-desktop=$BUILD_DESKTOP"
|
||||
echo "build-web=$BUILD_WEB"
|
||||
echo "build-web-cloud=$BUILD_WEB_CLOUD"
|
||||
echo "build-backend=$BUILD_BACKEND"
|
||||
@@ -117,9 +80,7 @@ jobs:
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
check-version-tag:
|
||||
@@ -128,15 +89,13 @@ jobs:
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
enable-cache: false
|
||||
|
||||
@@ -152,7 +111,7 @@ jobs:
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -164,138 +123,6 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: "macos-latest" # Build a universal image for macOS.
|
||||
args: "--target universal-apple-darwin"
|
||||
- platform: "ubuntu-24.04"
|
||||
args: "--bundles deb,rpm"
|
||||
- platform: "ubuntu-24.04-arm" # Only available in public repos.
|
||||
args: "--bundles deb,rpm"
|
||||
- platform: "windows-latest"
|
||||
args: ""
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
build-essential \
|
||||
libglib2.0-dev \
|
||||
libgirepository1.0-dev \
|
||||
libgtk-3-dev \
|
||||
libjavascriptcoregtk-4.1-dev \
|
||||
libwebkit2gtk-4.1-dev \
|
||||
libayatana-appindicator3-dev \
|
||||
gobject-introspection \
|
||||
pkg-config \
|
||||
curl \
|
||||
xdg-utils
|
||||
|
||||
- name: setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
|
||||
with:
|
||||
node-version: 24
|
||||
package-manager-cache: false
|
||||
|
||||
- name: install Rust stable
|
||||
uses: dtolnay/rust-toolchain@6d9817901c499d6b02debbb57edb38d33daa680b # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
# Those targets are only used on macos runners so it's in an `if` to slightly speed up windows and linux builds.
|
||||
targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }}
|
||||
|
||||
- name: install frontend dependencies
|
||||
working-directory: ./desktop
|
||||
run: npm install
|
||||
|
||||
- name: Inject version (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
working-directory: ./desktop
|
||||
env:
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
VERSION="0.0.0-dev+${SHORT_SHA}"
|
||||
else
|
||||
VERSION="${GITHUB_REF_NAME#v}"
|
||||
fi
|
||||
echo "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
sed "s/^version = .*/version = \"$VERSION\"/" src-tauri/Cargo.toml > src-tauri/Cargo.toml.tmp
|
||||
mv src-tauri/Cargo.toml.tmp src-tauri/Cargo.toml
|
||||
|
||||
# Update tauri.conf.json
|
||||
jq --arg v "$VERSION" '.version = $v' src-tauri/tauri.conf.json > src-tauri/tauri.conf.json.tmp
|
||||
mv src-tauri/tauri.conf.json.tmp src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
jq --arg v "$VERSION" '.version = $v' package.json > package.json.tmp
|
||||
mv package.json.tmp package.json
|
||||
|
||||
echo "Versions set to: $VERSION"
|
||||
|
||||
- name: Inject version (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
working-directory: ./desktop
|
||||
shell: pwsh
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
run: |
|
||||
# Windows MSI requires numeric-only build metadata, so we skip the SHA suffix
|
||||
if ($env:IS_TEST_RUN -eq "true") {
|
||||
$VERSION = "0.0.0"
|
||||
} else {
|
||||
# Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility
|
||||
$VERSION = "$env:GITHUB_REF_NAME" -replace '^v', '' -replace '-.*$', ''
|
||||
}
|
||||
Write-Host "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
$cargo = Get-Content src-tauri/Cargo.toml -Raw
|
||||
$cargo = $cargo -replace '(?m)^version = .*', "version = `"$VERSION`""
|
||||
Set-Content src-tauri/Cargo.toml $cargo -NoNewline
|
||||
|
||||
# Update tauri.conf.json
|
||||
$json = Get-Content src-tauri/tauri.conf.json | ConvertFrom-Json
|
||||
$json.version = $VERSION
|
||||
$json | ConvertTo-Json -Depth 100 | Set-Content src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
$pkg = Get-Content package.json | ConvertFrom-Json
|
||||
$pkg.version = $VERSION
|
||||
$pkg | ConvertTo-Json -Depth 100 | Set-Content package.json
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-web == 'true'
|
||||
@@ -313,15 +140,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -351,7 +178,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-arm64:
|
||||
@@ -371,15 +198,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -409,7 +236,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web:
|
||||
@@ -439,20 +266,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -479,15 +306,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -525,7 +352,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-cloud-arm64:
|
||||
@@ -545,15 +372,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -591,7 +418,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web-cloud:
|
||||
@@ -621,17 +448,17 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -658,15 +485,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -695,7 +522,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-arm64:
|
||||
@@ -715,15 +542,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -752,7 +579,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-backend:
|
||||
@@ -782,20 +609,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -823,15 +650,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -864,7 +691,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -887,15 +714,15 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -928,7 +755,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -960,20 +787,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -1006,7 +833,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1046,7 +873,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-cloud-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1080,7 +907,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1091,7 +918,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1133,7 +960,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -1152,8 +979,6 @@ jobs:
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs:
|
||||
- determine-builds
|
||||
- build-desktop
|
||||
- build-web-amd64
|
||||
- build-web-arm64
|
||||
- merge-web
|
||||
@@ -1166,13 +991,13 @@ jobs:
|
||||
- build-model-server-amd64
|
||||
- build-model-server-arm64
|
||||
- merge-model-server
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
|
||||
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1181,9 +1006,6 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
FAILED_JOBS=""
|
||||
if [ "${NEEDS_BUILD_DESKTOP_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-desktop\\n"
|
||||
fi
|
||||
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
|
||||
fi
|
||||
@@ -1224,7 +1046,6 @@ jobs:
|
||||
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
|
||||
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
|
||||
env:
|
||||
NEEDS_BUILD_DESKTOP_RESULT: ${{ needs.build-desktop.result }}
|
||||
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
|
||||
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
|
||||
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
31
.github/workflows/merge-group.yml
vendored
31
.github/workflows/merge-group.yml
vendored
@@ -1,31 +0,0 @@
|
||||
name: Merge Group-Specific
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
|
||||
# There is a similarly named "required" job in pr-integration-tests.yml which runs the actual
|
||||
# integration tests. That job runs on both pull_request and merge_group events, and this job
|
||||
# exists solely to provide a fast-passing check with the same name for branch protection.
|
||||
# The actual tests remain enforced on presubmit (pull_request events).
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Success
|
||||
run: echo "Success"
|
||||
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
|
||||
# There is a similarly named "playwright-required" job in pr-playwright-tests.yml which runs
|
||||
# the actual playwright tests. That job runs on both pull_request and merge_group events, and
|
||||
# this job exists solely to provide a fast-passing check with the same name for branch protection.
|
||||
# The actual tests remain enforced on presubmit (pull_request events).
|
||||
playwright-required:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Success
|
||||
run: echo "Success"
|
||||
2
.github/workflows/nightly-scan-licenses.yml
vendored
2
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
62
.github/workflows/pr-database-tests.yml
vendored
62
.github/workflows/pr-database-tests.yml
vendored
@@ -1,62 +0,0 @@
|
||||
name: Database Tests
|
||||
concurrency:
|
||||
group: Database-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
database-tests:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-database-tests"
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python and Install Dependencies
|
||||
uses: ./.github/actions/setup-python-and-install-dependencies
|
||||
with:
|
||||
requirements: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Start Docker containers
|
||||
working-directory: ./deployment/docker_compose
|
||||
run: |
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d \
|
||||
relational_db
|
||||
|
||||
- name: Run Database Tests
|
||||
working-directory: ./backend
|
||||
run: pytest -m alembic tests/integration/tests/migrations/
|
||||
@@ -38,8 +38,6 @@ env:
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
|
||||
VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}
|
||||
|
||||
# Code Interpreter
|
||||
# TODO: debug why this is failing and enable
|
||||
@@ -54,7 +52,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -82,13 +80,12 @@ jobs:
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
MODEL_SERVER_HOST: "disabled"
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -116,7 +113,6 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
|
||||
404
.github/workflows/pr-helm-chart-testing.yml
vendored
404
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -6,11 +6,11 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
branches: [ main ]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -18,233 +18,225 @@ permissions:
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-x64,
|
||||
hdd=256,
|
||||
"run-id=${{ github.run_id }}-helm-chart-check",
|
||||
]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
|
||||
timeout-minutes: 45
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@6ec842c01de15ebb84c8627d2744a0c2f2755c9f # ratchet:helm/chart-testing-action@v2.8.0
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
108
.github/workflows/pr-integration-tests.yml
vendored
108
.github/workflows/pr-integration-tests.yml
vendored
@@ -33,11 +33,6 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
|
||||
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
|
||||
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
|
||||
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
@@ -48,7 +43,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -56,7 +51,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -72,19 +67,14 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -132,19 +122,14 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -191,19 +176,14 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -240,7 +220,7 @@ jobs:
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
cd backend && docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
@@ -279,7 +259,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -294,29 +274,23 @@ jobs:
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
- name: Start Docker containers
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
MCP_SERVER_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -325,6 +299,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -367,6 +342,12 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -396,6 +377,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
@@ -411,11 +394,6 @@ jobs:
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
|
||||
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
|
||||
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
|
||||
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
@@ -444,22 +422,21 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-multitenant-tests",
|
||||
"extras=ecr-cache",
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -480,10 +457,10 @@ jobs:
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -492,6 +469,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
@@ -540,6 +518,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
|
||||
11
.github/workflows/pr-jest-tests.yml
vendored
11
.github/workflows/pr-jest-tests.yml
vendored
@@ -4,14 +4,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -23,12 +16,12 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
74
.github/workflows/pr-mit-integration-tests.yml
vendored
74
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -65,18 +65,12 @@ jobs:
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -125,18 +119,12 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -184,18 +172,12 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -232,7 +214,7 @@ jobs:
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
cd backend && docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
@@ -271,7 +253,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -286,27 +268,21 @@ jobs:
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
- name: Start Docker containers
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -315,6 +291,7 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -357,6 +334,12 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -387,6 +370,8 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
@@ -430,6 +415,7 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
49
.github/workflows/pr-playwright-tests.yml
vendored
49
.github/workflows/pr-playwright-tests.yml
vendored
@@ -4,14 +4,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -54,19 +47,13 @@ env:
|
||||
|
||||
jobs:
|
||||
build-web-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-web-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -115,19 +102,13 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -176,19 +157,13 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -254,15 +229,16 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: 'npm'
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
@@ -471,6 +447,7 @@ jobs:
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
@@ -488,12 +465,12 @@ jobs:
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
# uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
# uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
|
||||
55
.github/workflows/pr-python-checks.yml
vendored
55
.github/workflows/pr-python-checks.yml
vendored
@@ -17,6 +17,24 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
validate-requirements:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Validate requirements lock files
|
||||
run: ./backend/scripts/compile_requirements.py --check
|
||||
|
||||
mypy-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
@@ -27,7 +45,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -40,10 +58,35 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
backend/requirements/ee.txt
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
- name: Generate OpenAPI schema
|
||||
shell: bash
|
||||
working-directory: backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
shell: bash
|
||||
run: |
|
||||
ods openapi all
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Cache mypy cache
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
@@ -60,9 +103,3 @@ jobs:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
12
.github/workflows/pr-python-connector-tests.yml
vendored
12
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -133,13 +133,12 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -161,20 +160,16 @@ jobs:
|
||||
hubspot:
|
||||
- 'backend/onyx/connectors/hubspot/**'
|
||||
- 'backend/tests/daily/connectors/hubspot/**'
|
||||
- 'uv.lock'
|
||||
salesforce:
|
||||
- 'backend/onyx/connectors/salesforce/**'
|
||||
- 'backend/tests/daily/connectors/salesforce/**'
|
||||
- 'uv.lock'
|
||||
github:
|
||||
- 'backend/onyx/connectors/github/**'
|
||||
- 'backend/tests/daily/connectors/github/**'
|
||||
- 'uv.lock'
|
||||
file_processing:
|
||||
- 'backend/onyx/file_processing/**'
|
||||
- 'uv.lock'
|
||||
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
@@ -187,8 +182,7 @@ jobs:
|
||||
backend/tests/daily/connectors \
|
||||
--ignore backend/tests/daily/connectors/hubspot \
|
||||
--ignore backend/tests/daily/connectors/salesforce \
|
||||
--ignore backend/tests/daily/connectors/github \
|
||||
--ignore backend/tests/daily/connectors/coda
|
||||
--ignore backend/tests/daily/connectors/github
|
||||
|
||||
- name: Run HubSpot Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
6
.github/workflows/pr-python-tests.yml
vendored
6
.github/workflows/pr-python-tests.yml
vendored
@@ -26,13 +26,15 @@ jobs:
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
DISABLE_TELEMETRY: "true"
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
23
.github/workflows/pr-quality-checks.yml
vendored
23
.github/workflows/pr-quality-checks.yml
vendored
@@ -7,8 +7,6 @@ on:
|
||||
merge_group:
|
||||
pull_request: null
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
@@ -17,10 +15,17 @@ permissions:
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-quality-checks",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -30,7 +35,7 @@ jobs:
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
@@ -38,10 +43,12 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@91fd7d7cf70ae1dee9f4f44e7dfa5d1073fe6623 # ratchet:j178/prek-action@v1
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
|
||||
env:
|
||||
# uv-run is mypy's id and mypy is covered by the Python Checks which caches dependencies better.
|
||||
SKIP: uv-run
|
||||
with:
|
||||
prek-version: '0.2.21'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
with:
|
||||
|
||||
19
.github/workflows/release-devtools.yml
vendored
19
.github/workflows/release-devtools.yml
vendored
@@ -16,22 +16,21 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
- {goos: "linux", goarch: "amd64"}
|
||||
- {goos: "linux", goarch: "arm64"}
|
||||
- {goos: "windows", goarch: "amd64"}
|
||||
- {goos: "windows", goarch: "arm64"}
|
||||
- {goos: "darwin", goarch: "amd64"}
|
||||
- {goos: "darwin", goarch: "arm64"}
|
||||
- {goos: "", goarch: ""}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
|
||||
2
.github/workflows/sync_foss.yml
vendored
2
.github/workflows/sync_foss.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout main Onyx repo
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/tag-nightly.yml
vendored
2
.github/workflows/tag-nightly.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
ssh-key: "${{ secrets.DEPLOY_KEY }}"
|
||||
persist-credentials: true
|
||||
|
||||
16
.github/workflows/zizmor.yml
vendored
16
.github/workflows/zizmor.yml
vendored
@@ -17,33 +17,21 @@ jobs:
|
||||
security-events: write # needed for SARIF uploads
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -53,6 +53,3 @@ node_modules
|
||||
|
||||
# MCP configs
|
||||
.playwright-mcp
|
||||
|
||||
# plans
|
||||
plans/
|
||||
|
||||
@@ -5,138 +5,79 @@ default_install_hook_types:
|
||||
- post-rewrite
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: 569ddf04117761eb74cef7afb5143bbb96fcdfbb # frozen: 0.9.15
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export dev.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export ee.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export model_server.txt
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
# args: ["--all-extras", "mypy"]
|
||||
# args: ["mypy"]
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
files: ^.github/
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# this is a fork which keeps compatibility with black
|
||||
- repo: https://github.com/wimglenn/reorder-python-imports-black
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ["--py311-plus", "--application-directories=backend/"]
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args:
|
||||
[
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
"--in-place",
|
||||
"--recursive",
|
||||
]
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: e6ebea0145f385056bce15041d3244c0e5e15848 # frozen: v2.7.0
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -147,13 +88,15 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
|
||||
# This is a preview package - if it breaks:
|
||||
# 1. Try updating: cd web && npm update @typescript/native-preview
|
||||
# 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below
|
||||
- id: check-lazy-imports
|
||||
name: Check lazy imports
|
||||
entry: python3 backend/scripts/check_lazy_imports.py
|
||||
language: system
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
|
||||
- id: typescript-check
|
||||
name: TypeScript type check
|
||||
entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json'
|
||||
entry: bash -c 'cd web && npm run types:check'
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: ^web/.*\.(ts|tsx)$
|
||||
|
||||
51
.vscode/env_template.txt
vendored
51
.vscode/env_template.txt
vendored
@@ -1,45 +1,36 @@
|
||||
# Copy this file to .env in the .vscode folder.
|
||||
# Fill in the <REPLACE THIS> values as needed; it is recommended to set the
|
||||
# GEN_AI_API_KEY value to avoid having to set up an LLM in the UI.
|
||||
# Also check out onyx/backend/scripts/restart_containers.sh for a script to
|
||||
# restart the containers which Onyx relies on outside of VSCode/Cursor
|
||||
# processes.
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes
|
||||
|
||||
|
||||
# For local dev, often user Authentication is not needed.
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
|
||||
# Always keep these on for Dev.
|
||||
# Logs model prompts, reasoning, and answer to stdout.
|
||||
# Always keep these on for Dev
|
||||
# Logs model prompts, reasoning, and answer to stdout
|
||||
LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to
|
||||
# answer generation.
|
||||
# This step is quite heavy on token usage so we disable it for dev generally.
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
OPENID_CONFIG_URL=<REPLACE THIS>
|
||||
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
|
||||
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server
|
||||
# for dev.
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI
|
||||
# every time.
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
OPENAI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper.
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
@@ -49,36 +40,26 @@ PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features.
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you
|
||||
# are using this for local testing/development).
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
# S3 File Store Configuration (MinIO for local development)
|
||||
S3_ENDPOINT_URL=http://localhost:9004
|
||||
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
|
||||
S3_AWS_ACCESS_KEY_ID=minioadmin
|
||||
S3_AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
|
||||
# Show extra/uncommon connectors.
|
||||
# Show extra/uncommon connectors
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
|
||||
|
||||
# Local langsmith tracing
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
|
||||
|
||||
# Local Confluence OAuth testing
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
|
||||
# NEXT_PUBLIC_TEST_ENV=True
|
||||
|
||||
|
||||
# OpenSearch
|
||||
# Arbitrary password is fine for local development.
|
||||
OPENSEARCH_INITIAL_ADMIN_PASSWORD=<REPLACE THIS>
|
||||
# NEXT_PUBLIC_TEST_ENV=True
|
||||
22
.vscode/launch.template.jsonc
vendored
22
.vscode/launch.template.jsonc
vendored
@@ -508,21 +508,7 @@
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart OpenSearch Container",
|
||||
// Generic debugger type, required arg but has no bearing on bash.
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
@@ -568,10 +554,10 @@
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"sync",
|
||||
"--all-extras"
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to AI agents when working with code in this repository.
|
||||
This file provides guidance to Codex when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
@@ -181,286 +181,6 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `web/tailwind-themes/tailwind.config.js` / `web/src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -575,6 +295,14 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
@@ -184,286 +184,6 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -580,6 +300,14 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -71,12 +71,12 @@ If using a higher version, sometimes some libraries will not be available (i.e.
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
uv venv .venv --python 3.11
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
@@ -95,15 +95,33 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
pip install -r backend/requirements/combined.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
or
|
||||
|
||||
```bash
|
||||
uv run playwright install
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Fix vscode/cursor auto-imports:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
@@ -112,7 +130,7 @@ to manage your Node installations. Once installed, you can run
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
@@ -126,15 +144,21 @@ npm i
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
|
||||
Then run:
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
@@ -161,7 +185,7 @@ You will need Docker installed to run these containers.
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
|
||||
docker compose up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
@@ -15,4 +15,3 @@ build/
|
||||
dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
model_server/legacy/
|
||||
|
||||
@@ -13,10 +13,23 @@ RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt && \
|
||||
rm -rf ~/.cache/uv /tmp/*.txt
|
||||
|
||||
# Stage for downloading embedding models
|
||||
# Stage for downloading tokenizers
|
||||
FROM base AS tokenizers
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1');"
|
||||
|
||||
# Stage for downloading Onyx models
|
||||
FROM base AS onyx-models
|
||||
RUN python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/information-content-model');"
|
||||
|
||||
# Stage for downloading embedding and reranking models
|
||||
FROM base AS embedding-models
|
||||
RUN python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1');"
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"
|
||||
|
||||
# Initialize SentenceTransformer to cache the custom architecture
|
||||
RUN python -c "from sentence_transformers import SentenceTransformer; \
|
||||
@@ -41,6 +54,8 @@ RUN groupadd -g 1001 onyx && \
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
COPY --chown=onyx:onyx --from=tokenizers /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
COPY --chown=onyx:onyx --from=onyx-models /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -7,12 +7,8 @@ Onyx migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
|
||||
From onyx/backend, run:
|
||||
`alembic revision -m <DESCRIPTION_OF_MIGRATION>`
|
||||
|
||||
Note: you cannot use the `--autogenerate` flag as the automatic schema parsing does not work.
|
||||
|
||||
Manually populate the upgrade and downgrade in your new migration.
|
||||
run from onyx/backend:
|
||||
`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
|
||||
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
|
||||
@@ -39,9 +39,7 @@ config = context.config
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
|
||||
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
@@ -462,49 +460,8 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
Supports pytest-alembic by checking for a pre-configured connection
|
||||
in context.config.attributes["connection"]. If present, uses that
|
||||
connection/engine directly instead of creating a new async engine.
|
||||
"""
|
||||
# Check if pytest-alembic is providing a connection/engine
|
||||
connectable = context.config.attributes.get("connection", None)
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
logger.info("run_migrations_online starting (pytest-alembic mode).")
|
||||
|
||||
# For pytest-alembic, we use the default schema (public)
|
||||
schema_name = context.config.attributes.get(
|
||||
"schema_name", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
# pytest-alembic passes an Engine, we need to get a connection from it
|
||||
with connectable.connect() as connection:
|
||||
# Set search path for the schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
# Commit the transaction to ensure changes are visible to next migration
|
||||
connection.commit()
|
||||
else:
|
||||
# Normal operation - use async migrations
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""add is_clarification to chat_message
|
||||
|
||||
Revision ID: 18b5b2524446
|
||||
Revises: 87c52ec39f84
|
||||
Create Date: 2025-01-16
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "18b5b2524446"
|
||||
down_revision = "87c52ec39f84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"is_clarification", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "is_clarification")
|
||||
@@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add last refreshed at mcp server
|
||||
|
||||
Revision ID: 2a391f840e85
|
||||
Revises: 4cebcbc9b2ae
|
||||
Create Date: 2025-12-06 15:19:59.766066
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembi.
|
||||
revision = "2a391f840e85"
|
||||
down_revision = "4cebcbc9b2ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("mcp_server", "last_refreshed_at")
|
||||
@@ -1,46 +0,0 @@
|
||||
"""usage_limits
|
||||
|
||||
Revision ID: 2b90f3af54b8
|
||||
Revises: 9a0296d7421e
|
||||
Create Date: 2026-01-03 16:55:30.449692
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2b90f3af54b8"
|
||||
down_revision = "9a0296d7421e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tenant_usage",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"window_start", sa.DateTime(timezone=True), nullable=False, index=True
|
||||
),
|
||||
sa.Column("llm_cost_cents", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("chunks_indexed", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("api_calls", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"non_streaming_api_calls", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("window_start", name="uq_tenant_usage_window"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_tenant_usage_window_start", table_name="tenant_usage")
|
||||
op.drop_table("tenant_usage")
|
||||
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.llm.llm_provider_options import (
|
||||
fetch_model_names_for_provider_as_set,
|
||||
fetch_visible_model_names_for_provider_as_set,
|
||||
)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add tab_index to tool_call
|
||||
|
||||
Revision ID: 4cebcbc9b2ae
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2025-12-16
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4cebcbc9b2ae"
|
||||
down_revision = "a1b2c3d4e5f6"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool_call", "tab_index")
|
||||
@@ -62,11 +62,6 @@ def upgrade() -> None:
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the temporary table to avoid conflicts if migration runs again
|
||||
# (e.g., during upgrade -> downgrade -> upgrade cycles in tests)
|
||||
op.execute("DROP TABLE IF EXISTS temp_connector_credential")
|
||||
|
||||
# If no exception was raised, alter the column
|
||||
op.alter_column("credential", "source", nullable=True) # TODO modify
|
||||
# # ### end Alembic commands ###
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
"""backend driven notification details
|
||||
|
||||
Revision ID: 5c3dca366b35
|
||||
Revises: 9087b548dd69
|
||||
Create Date: 2026-01-06 16:03:11.413724
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c3dca366b35"
|
||||
down_revision = "9087b548dd69"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column(
|
||||
"title", sa.String(), nullable=False, server_default="New Notification"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column("description", sa.String(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "title")
|
||||
op.drop_column("notification", "description")
|
||||
@@ -1,75 +0,0 @@
|
||||
"""nullify_default_task_prompt
|
||||
|
||||
Revision ID: 699221885109
|
||||
Revises: 7e490836d179
|
||||
Create Date: 2025-12-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "699221885109"
|
||||
down_revision = "7e490836d179"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make task_prompt column nullable
|
||||
# Note: The model had nullable=True but the DB column was NOT NULL until this point
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set task_prompt to NULL for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = NULL
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore task_prompt to empty string for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE id = :persona_id AND task_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
# Set any remaining NULL task_prompts to empty string before making non-nullable
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE task_prompt IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Revert task_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,54 +0,0 @@
|
||||
"""add image generation config table
|
||||
|
||||
Revision ID: 7206234e012a
|
||||
Revises: 699221885109
|
||||
Create Date: 2025-12-21 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7206234e012a"
|
||||
down_revision = "699221885109"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"image_generation_config",
|
||||
sa.Column("image_provider_id", sa.String(), primary_key=True),
|
||||
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
|
||||
sa.Column("is_default", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_configuration_id"],
|
||||
["model_configuration.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_is_default",
|
||||
"image_generation_config",
|
||||
["is_default"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
"image_generation_config",
|
||||
["model_configuration_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
table_name="image_generation_config",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_is_default", table_name="image_generation_config"
|
||||
)
|
||||
op.drop_table("image_generation_config")
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.llm.llm_provider_options import (
|
||||
fetch_model_names_for_provider_as_set,
|
||||
fetch_visible_model_names_for_provider_as_set,
|
||||
)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add display_name to model_configuration
|
||||
|
||||
Revision ID: 7bd55f264e1b
|
||||
Revises: e8f0d2a38171
|
||||
Create Date: 2025-12-04
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7bd55f264e1b"
|
||||
down_revision = "e8f0d2a38171"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"model_configuration",
|
||||
sa.Column("display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("model_configuration", "display_name")
|
||||
@@ -1,80 +0,0 @@
|
||||
"""nullify_default_system_prompt
|
||||
|
||||
Revision ID: 7e490836d179
|
||||
Revises: c1d2e3f4a5b6
|
||||
Create Date: 2025-12-29 16:54:36.635574
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7e490836d179"
|
||||
down_revision = "c1d2e3f4a5b6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# This is the default system prompt from the previous migration (87c52ec39f84)
|
||||
# ruff: noqa: E501, W605 start
|
||||
PREVIOUS_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make system_prompt column nullable (model already has nullable=True but DB doesn't)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set system_prompt to NULL where it matches the previous default
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = NULL
|
||||
WHERE system_prompt = :previous_default
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the default system prompt for personas that have NULL
|
||||
# Note: This may restore the prompt to personas that originally had NULL
|
||||
# before this migration, but there's no way to distinguish them
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :previous_default
|
||||
WHERE system_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
# Revert system_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -42,13 +42,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
@@ -63,13 +63,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
"""update_default_system_prompt
|
||||
|
||||
Revision ID: 87c52ec39f84
|
||||
Revises: 7bd55f264e1b
|
||||
Create Date: 2025-12-05 15:54:06.002452
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "87c52ec39f84"
|
||||
down_revision = "7bd55f264e1b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# ruff: noqa: E501, W605 start
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :system_prompt
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't revert the system prompt on downgrade since we don't know
|
||||
# what the previous value was. The new prompt is a reasonable default.
|
||||
pass
|
||||
@@ -1,136 +0,0 @@
|
||||
"""seed_default_image_gen_config
|
||||
|
||||
Revision ID: 9087b548dd69
|
||||
Revises: 2b90f3af54b8
|
||||
Create Date: 2026-01-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9087b548dd69"
|
||||
down_revision = "2b90f3af54b8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for default image generation config
|
||||
# Source: web/src/app/admin/configuration/image-generation/constants.ts
|
||||
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
|
||||
MODEL_NAME = "gpt-image-1"
|
||||
PROVIDER_NAME = "openai"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if image_generation_config table already has records
|
||||
existing_configs = (
|
||||
conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
if existing_configs > 0:
|
||||
# Skip if configs already exist - user may have configured manually
|
||||
return
|
||||
|
||||
# Find the first OpenAI LLM provider
|
||||
openai_provider = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, api_key
|
||||
FROM llm_provider
|
||||
WHERE provider = :provider
|
||||
ORDER BY id
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"provider": PROVIDER_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not openai_provider:
|
||||
# No OpenAI provider found - nothing to do
|
||||
return
|
||||
|
||||
source_provider_id, api_key = openai_provider
|
||||
|
||||
# Create new LLM provider for image generation (clone only api_key)
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO llm_provider (
|
||||
name, provider, api_key, api_base, api_version,
|
||||
deployment_name, default_model_name, is_public,
|
||||
is_default_provider, is_default_vision_provider, is_auto_mode
|
||||
)
|
||||
VALUES (
|
||||
:name, :provider, :api_key, NULL, NULL,
|
||||
NULL, :default_model_name, :is_public,
|
||||
NULL, NULL, :is_auto_mode
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"name": f"Image Gen - {IMAGE_PROVIDER_ID}",
|
||||
"provider": PROVIDER_NAME,
|
||||
"api_key": api_key,
|
||||
"default_model_name": MODEL_NAME,
|
||||
"is_public": True,
|
||||
"is_auto_mode": False,
|
||||
},
|
||||
)
|
||||
new_provider_id = result.scalar()
|
||||
|
||||
# Create model configuration
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO model_configuration (
|
||||
llm_provider_id, name, is_visible, max_input_tokens,
|
||||
supports_image_input, display_name
|
||||
)
|
||||
VALUES (
|
||||
:llm_provider_id, :name, :is_visible, :max_input_tokens,
|
||||
:supports_image_input, :display_name
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"llm_provider_id": new_provider_id,
|
||||
"name": MODEL_NAME,
|
||||
"is_visible": True,
|
||||
"max_input_tokens": None,
|
||||
"supports_image_input": False,
|
||||
"display_name": None,
|
||||
},
|
||||
)
|
||||
model_config_id = result.scalar()
|
||||
|
||||
# Create image generation config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO image_generation_config (
|
||||
image_provider_id, model_configuration_id, is_default
|
||||
)
|
||||
VALUES (
|
||||
:image_provider_id, :model_configuration_id, :is_default
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"image_provider_id": IMAGE_PROVIDER_ID,
|
||||
"model_configuration_id": model_config_id,
|
||||
"is_default": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the config on downgrade since it's safe to keep around
|
||||
# If we upgrade again, it will be a no-op due to the existing records check
|
||||
pass
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add_is_auto_mode_to_llm_provider
|
||||
|
||||
Revision ID: 9a0296d7421e
|
||||
Revises: 7206234e012a
|
||||
Create Date: 2025-12-17 18:14:29.620981
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9a0296d7421e"
|
||||
down_revision = "7206234e012a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_auto_mode",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "is_auto_mode")
|
||||
@@ -234,8 +234,6 @@ def downgrade() -> None:
|
||||
if "instructions" in columns:
|
||||
op.drop_column("user_project", "instructions")
|
||||
op.execute("ALTER TABLE user_project RENAME TO user_folder")
|
||||
# Update NULL descriptions to empty string before setting NOT NULL constraint
|
||||
op.execute("UPDATE user_folder SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("user_folder", "description", nullable=False)
|
||||
logger.info("Renamed user_project back to user_folder")
|
||||
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
"""update_default_tool_descriptions
|
||||
|
||||
Revision ID: a01bf2971c5d
|
||||
Revises: 87c52ec39f84
|
||||
Create Date: 2025-12-16 15:21:25.656375
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a01bf2971c5d"
|
||||
down_revision = "18b5b2524446"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# new tool descriptions (12/2025)
|
||||
TOOL_DESCRIPTIONS = {
|
||||
"SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
|
||||
"ImageGenerationTool": (
|
||||
"The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the agent to generate an image."
|
||||
),
|
||||
"WebSearchTool": (
|
||||
"The Web Search Action allows the agent "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"KnowledgeGraphTool": (
|
||||
"The Knowledge Graph Search Action allows the agent to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"OktaProfileTool": (
|
||||
"The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,49 +0,0 @@
|
||||
"""add license table
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: a01bf2971c5d
|
||||
Create Date: 2025-12-04 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f6"
|
||||
down_revision = "a01bf2971c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"license",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("license_data", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Singleton pattern - only ever one row in this table
|
||||
op.create_index(
|
||||
"idx_license_singleton",
|
||||
"license",
|
||||
[sa.text("(true)")],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_license_singleton", table_name="license")
|
||||
op.drop_table("license")
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Remove fast_default_model_name from llm_provider
|
||||
|
||||
Revision ID: a2b3c4d5e6f7
|
||||
Revises: 2a391f840e85
|
||||
Create Date: 2024-12-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a2b3c4d5e6f7"
|
||||
down_revision = "2a391f840e85"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("llm_provider", "fast_default_model_name")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("fast_default_model_name", sa.String(), nullable=True),
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
"""remove userfile related deprecated fields
|
||||
|
||||
Revision ID: a3c1a7904cd0
|
||||
Revises: 5c3dca366b35
|
||||
Create Date: 2026-01-06 13:00:30.634396
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3c1a7904cd0"
|
||||
down_revision = "5c3dca366b35"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user_file", "document_id")
|
||||
op.drop_column("user_file", "document_id_migrated")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
@@ -280,14 +280,6 @@ def downgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
# Recreate the FK constraint that was implicitly dropped when the column was dropped
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Drop milestone table
|
||||
|
||||
Revision ID: b8c9d0e1f2a3
|
||||
Revises: a2b3c4d5e6f7
|
||||
Create Date: 2025-12-18
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b8c9d0e1f2a3"
|
||||
down_revision = "a2b3c4d5e6f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
@@ -1,52 +0,0 @@
|
||||
"""add_deep_research_tool
|
||||
|
||||
Revision ID: c1d2e3f4a5b6
|
||||
Revises: b8c9d0e1f2a3
|
||||
Create Date: 2025-12-18 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c1d2e3f4a5b6"
|
||||
down_revision = "b8c9d0e1f2a3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
|
||||
"""
|
||||
),
|
||||
DEEP_RESEARCH_TOOL,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid)
|
||||
file_metadata = cast(Any, file_record.file_metadata)
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
else:
|
||||
# Convert other types to dict if possible, otherwise None
|
||||
try:
|
||||
file_metadata = dict(file_record.file_metadata)
|
||||
file_metadata = dict(file_record.file_metadata) # type: ignore
|
||||
except (TypeError, ValueError):
|
||||
file_metadata = None
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import (
|
||||
from onyx.db.enums import ( # type: ignore[import-untyped]
|
||||
MCPTransport,
|
||||
MCPAuthenticationType,
|
||||
MCPAuthenticationPerformer,
|
||||
|
||||
@@ -20,9 +20,7 @@ config = context.config
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
|
||||
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
@@ -84,9 +82,9 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -110,24 +108,9 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
"""Run migrations in 'online' mode."""
|
||||
|
||||
Supports pytest-alembic by checking for a pre-configured connection
|
||||
in context.config.attributes["connection"]. If present, uses that
|
||||
connection/engine directly instead of creating a new async engine.
|
||||
"""
|
||||
# Check if pytest-alembic is providing a connection/engine
|
||||
connectable = context.config.attributes.get("connection", None)
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
with connectable.connect() as connection:
|
||||
do_run_migrations(connection)
|
||||
# Commit to ensure changes are visible to next migration
|
||||
connection.commit()
|
||||
else:
|
||||
# Normal operation - use async migrations
|
||||
asyncio.run(run_async_migrations())
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
group "default" {
|
||||
targets = ["backend", "model-server", "web"]
|
||||
targets = ["backend", "model-server"]
|
||||
}
|
||||
|
||||
variable "BACKEND_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-backend"
|
||||
}
|
||||
|
||||
variable "WEB_SERVER_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-web-server"
|
||||
}
|
||||
|
||||
variable "MODEL_SERVER_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-model-server"
|
||||
}
|
||||
@@ -23,7 +19,7 @@ variable "TAG" {
|
||||
}
|
||||
|
||||
target "backend" {
|
||||
context = "backend"
|
||||
context = "."
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
|
||||
@@ -32,18 +28,8 @@ target "backend" {
|
||||
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
target "web" {
|
||||
context = "web"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = ["type=registry,ref=${WEB_SERVER_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${WEB_SERVER_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
target "model-server" {
|
||||
context = "backend"
|
||||
context = "."
|
||||
|
||||
dockerfile = "Dockerfile.model_server"
|
||||
|
||||
@@ -54,7 +40,7 @@ target "model-server" {
|
||||
}
|
||||
|
||||
target "integration" {
|
||||
context = "backend"
|
||||
context = "."
|
||||
dockerfile = "tests/integration/Dockerfile"
|
||||
|
||||
// Provide the base image via build context from the backend target
|
||||
@@ -111,6 +111,10 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
|
||||
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
|
||||
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
|
||||
@@ -118,6 +118,6 @@ def fetch_document_sets(
|
||||
.all()
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs))
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
||||
|
||||
return document_set_with_cc_pairs
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_license(db_session: Session) -> License | None:
|
||||
"""
|
||||
Get the current license (singleton pattern - only one row).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
License object if exists, None otherwise
|
||||
"""
|
||||
return db_session.execute(select(License)).scalars().first()
|
||||
|
||||
|
||||
def upsert_license(db_session: Session, license_data: str) -> License:
|
||||
"""
|
||||
Insert or update the license (singleton pattern).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
license_data: Base64-encoded signed license blob
|
||||
|
||||
Returns:
|
||||
The created or updated License object
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
|
||||
if existing:
|
||||
existing.license_data = license_data
|
||||
db_session.commit()
|
||||
db_session.refresh(existing)
|
||||
logger.info("License updated")
|
||||
return existing
|
||||
|
||||
new_license = License(license_data=license_data)
|
||||
db_session.add(new_license)
|
||||
db_session.commit()
|
||||
db_session.refresh(new_license)
|
||||
logger.info("License created")
|
||||
return new_license
|
||||
|
||||
|
||||
def delete_license(db_session: Session) -> bool:
|
||||
"""
|
||||
Delete the current license.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if no license existed
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
logger.info("License deleted")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Seat Counting
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis Cache Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
def update_license_cache(
|
||||
payload: LicensePayload,
|
||||
source: LicenseSource | None = None,
|
||||
grace_period_end: datetime | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the Redis cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
2. Caching avoids repeated DB + crypto verification on every request
|
||||
3. Status enforcement happens at the feature level, not here
|
||||
|
||||
Args:
|
||||
payload: Verified license payload
|
||||
source: How the license was obtained
|
||||
grace_period_end: Optional grace period end time
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
The cached LicenseMetadata
|
||||
"""
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id=payload.tenant_id,
|
||||
organization_name=payload.organization_name,
|
||||
seats=payload.seats,
|
||||
used_seats=used_seats,
|
||||
plan_type=payload.plan_type,
|
||||
issued_at=payload.issued_at,
|
||||
expires_at=payload.expires_at,
|
||||
grace_period_end=grace_period_end,
|
||||
status=status,
|
||||
source=source,
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
redis_client.setex(
|
||||
LICENSE_METADATA_KEY,
|
||||
LICENSE_CACHE_TTL_SECONDS,
|
||||
metadata.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
|
||||
return metadata
|
||||
|
||||
|
||||
def refresh_license_cache(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Refresh the license cache from the database.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
|
||||
license_record = get_license(db_session)
|
||||
if not license_record:
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license during cache refresh: {e}")
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
|
||||
def get_license_metadata(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata, using cache if available.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
# Try cache first
|
||||
cached = get_cached_license_metadata(tenant_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
@@ -34,7 +34,6 @@ def make_persona_private(
|
||||
create_notification(
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
title="A new agent was shared with you!",
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -363,29 +362,14 @@ def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
|
||||
|
||||
def _add_user__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, user_ids: list[UUID]
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction.
|
||||
|
||||
This function is idempotent - it will skip users who are already in the group
|
||||
to avoid duplicate key violations during concurrent operations or re-syncs.
|
||||
Uses ON CONFLICT DO NOTHING to keep inserts atomic under concurrency.
|
||||
"""
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
insert_stmt = (
|
||||
insert(User__UserGroup)
|
||||
.values(
|
||||
[
|
||||
{"user_id": user_id, "user_group_id": user_group_id}
|
||||
for user_id in user_ids
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id]
|
||||
)
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
) -> list[User__UserGroup]:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
relationships = [
|
||||
User__UserGroup(user_id=user_id, user_group_id=user_group_id)
|
||||
for user_id in user_ids
|
||||
]
|
||||
db_session.add_all(relationships)
|
||||
return relationships
|
||||
|
||||
|
||||
def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
@@ -14,7 +14,6 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
@@ -140,8 +139,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
|
||||
include_router_with_global_prefix_prepended(application, usage_export_router)
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -21,9 +21,8 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
router = APIRouter(prefix="/analytics")
|
||||
|
||||
|
||||
_DEFAULT_LOOKBACK_DAYS = 30
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
@@ -24,12 +23,6 @@ class NavigationItem(BaseModel):
|
||||
return instance
|
||||
|
||||
|
||||
class LogoDisplayStyle(str, Enum):
|
||||
LOGO_AND_NAME = "logo_and_name"
|
||||
LOGO_ONLY = "logo_only"
|
||||
NAME_ONLY = "name_only"
|
||||
|
||||
|
||||
class EnterpriseSettings(BaseModel):
|
||||
"""General settings that only apply to the Enterprise Edition of Onyx
|
||||
|
||||
@@ -38,7 +31,6 @@ class EnterpriseSettings(BaseModel):
|
||||
application_name: str | None = None
|
||||
use_custom_logo: bool = False
|
||||
use_custom_logotype: bool = False
|
||||
logo_display_style: LogoDisplayStyle | None = None
|
||||
|
||||
# custom navigation
|
||||
custom_nav_items: List[NavigationItem] = Field(default_factory=list)
|
||||
@@ -50,9 +42,6 @@ class EnterpriseSettings(BaseModel):
|
||||
custom_popup_header: str | None = None
|
||||
custom_popup_content: str | None = None
|
||||
enable_consent_screen: bool | None = None
|
||||
consent_screen_prompt: str | None = None
|
||||
show_first_visit_notice: bool | None = None
|
||||
custom_greeting_message: str | None = None
|
||||
|
||||
def check_validity(self) -> None:
|
||||
return
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
"""License API endpoints."""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseResponse
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/license")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_license_status(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""Get current license status and seat usage."""
|
||||
metadata = get_license_metadata(db_session)
|
||||
|
||||
if not metadata:
|
||||
return LicenseStatusResponse(has_license=False)
|
||||
|
||||
return LicenseStatusResponse(
|
||||
has_license=True,
|
||||
seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
plan_type=metadata.plan_type,
|
||||
issued_at=metadata.issued_at,
|
||||
expires_at=metadata.expires_at,
|
||||
grace_period_end=metadata.grace_period_end,
|
||||
status=metadata.status,
|
||||
source=metadata.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/seats")
|
||||
async def get_seat_usage(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUsageResponse:
|
||||
"""Get detailed seat usage information."""
|
||||
metadata = get_license_metadata(db_session)
|
||||
|
||||
if not metadata:
|
||||
return SeatUsageResponse(
|
||||
total_seats=0,
|
||||
used_seats=0,
|
||||
available_seats=0,
|
||||
)
|
||||
|
||||
return SeatUsageResponse(
|
||||
total_seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
available_seats=max(0, metadata.seats - metadata.used_seats),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fetch")
|
||||
async def fetch_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
Fetch license from control plane.
|
||||
Used after Stripe checkout completion to retrieve the new license.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to generate data plane token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Authentication configuration error"
|
||||
)
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
if not isinstance(data, dict) or "license" not in data:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Invalid response from control plane"
|
||||
)
|
||||
|
||||
license_data = data["license"]
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license found")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Verify the fetched license is for this tenant
|
||||
if payload.tenant_id != tenant_id:
|
||||
logger.error(
|
||||
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License tenant ID mismatch - control plane returned wrong license",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache atomically
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseResponse(success=True, license=payload)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else 502
|
||||
logger.error(f"Control plane returned error: {status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail="Failed to fetch license from control plane",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"License verification failed: {type(e).__name__}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
logger.exception("Failed to fetch license from control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_license(
|
||||
license_file: UploadFile = File(...),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
Upload a license file manually.
|
||||
Used for air-gapped deployments where control plane is not accessible.
|
||||
"""
|
||||
try:
|
||||
content = await license_file.read()
|
||||
license_data = content.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if payload.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseUploadResponse(
|
||||
success=True,
|
||||
message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_license_cache_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
Force refresh the license cache from the database.
|
||||
Useful after manual database changes or to verify license validity.
|
||||
"""
|
||||
metadata = refresh_license_cache(db_session)
|
||||
|
||||
if not metadata:
|
||||
return LicenseStatusResponse(has_license=False)
|
||||
|
||||
return LicenseStatusResponse(
|
||||
has_license=True,
|
||||
seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
plan_type=metadata.plan_type,
|
||||
issued_at=metadata.issued_at,
|
||||
expires_at=metadata.expires_at,
|
||||
grace_period_end=metadata.grace_period_end,
|
||||
status=metadata.status,
|
||||
source=metadata.source,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("")
|
||||
async def delete_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Delete the current license.
|
||||
Admin only - removes license and invalidates cache.
|
||||
"""
|
||||
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
|
||||
try:
|
||||
invalidate_license_cache()
|
||||
except Exception as cache_error:
|
||||
logger.warning(f"Failed to invalidate license cache: {cache_error}")
|
||||
|
||||
deleted = db_delete_license(db_session)
|
||||
|
||||
return {"deleted": deleted}
|
||||
@@ -1,92 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class PlanType(str, Enum):
|
||||
MONTHLY = "monthly"
|
||||
ANNUAL = "annual"
|
||||
|
||||
|
||||
class LicenseSource(str, Enum):
|
||||
AUTO_FETCH = "auto_fetch"
|
||||
MANUAL_UPLOAD = "manual_upload"
|
||||
|
||||
|
||||
class LicensePayload(BaseModel):
|
||||
"""The payload portion of a signed license."""
|
||||
|
||||
version: str
|
||||
tenant_id: str
|
||||
organization_name: str | None = None
|
||||
issued_at: datetime
|
||||
expires_at: datetime
|
||||
seats: int
|
||||
plan_type: PlanType
|
||||
billing_cycle: str | None = None
|
||||
grace_period_days: int = 30
|
||||
stripe_subscription_id: str | None = None
|
||||
stripe_customer_id: str | None = None
|
||||
|
||||
|
||||
class LicenseData(BaseModel):
|
||||
"""Full signed license structure."""
|
||||
|
||||
payload: LicensePayload
|
||||
signature: str
|
||||
|
||||
|
||||
class LicenseMetadata(BaseModel):
|
||||
"""Cached license metadata stored in Redis."""
|
||||
|
||||
tenant_id: str
|
||||
organization_name: str | None = None
|
||||
seats: int
|
||||
used_seats: int
|
||||
plan_type: PlanType
|
||||
issued_at: datetime
|
||||
expires_at: datetime
|
||||
grace_period_end: datetime | None = None
|
||||
status: ApplicationStatus
|
||||
source: LicenseSource | None = None
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
class LicenseStatusResponse(BaseModel):
|
||||
"""Response for license status API."""
|
||||
|
||||
has_license: bool
|
||||
seats: int = 0
|
||||
used_seats: int = 0
|
||||
plan_type: PlanType | None = None
|
||||
issued_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
grace_period_end: datetime | None = None
|
||||
status: ApplicationStatus | None = None
|
||||
source: LicenseSource | None = None
|
||||
|
||||
|
||||
class LicenseResponse(BaseModel):
|
||||
"""Response after license fetch/upload."""
|
||||
|
||||
success: bool
|
||||
message: str | None = None
|
||||
license: LicensePayload | None = None
|
||||
|
||||
|
||||
class LicenseUploadResponse(BaseModel):
|
||||
"""Response after license upload."""
|
||||
|
||||
success: bool
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class SeatUsageResponse(BaseModel):
|
||||
"""Response for seat usage API."""
|
||||
|
||||
total_seats: int
|
||||
used_seats: int
|
||||
available_seats: int
|
||||
@@ -8,10 +8,12 @@ from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
@@ -20,8 +22,9 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -100,12 +103,14 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
use_agentic_search=chat_message_req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
@@ -156,13 +161,15 @@ def handle_send_message_simple_with_history(
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
max_history_tokens = int(llm.config.max_input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
@@ -181,6 +188,17 @@ def handle_send_message_simple_with_history(
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=msg_history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = req.query_override or thread_based_query_rephrase(
|
||||
user_query=query,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
if req.retrieval_options is None and req.search_doc_ids is None:
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
@@ -198,17 +216,19 @@ def handle_send_message_simple_with_history(
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
query_override=rephrased_query,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
use_agentic_search=req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
|
||||
@@ -54,6 +54,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
@@ -73,6 +76,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
|
||||
@@ -48,7 +48,6 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -295,7 +294,7 @@ def list_all_query_history_exports(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
@router.post("/admin/query-history/start-export")
|
||||
def start_query_history_export(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -341,7 +340,7 @@ def start_query_history_export(
|
||||
return {"request_id": task_id}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
@router.get("/admin/query-history/export-status")
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
@@ -375,7 +374,7 @@ def get_query_history_export_status(
|
||||
return {"status": TaskStatus.SUCCESS}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
@router.get("/admin/query-history/download")
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tenant_id to their specific limit overrides.
|
||||
Returns empty dict on any error (falls back to defaults).
|
||||
"""
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides"
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
tenant_overrides = response.json()
|
||||
|
||||
# Parse each tenant's overrides
|
||||
result: dict[str, TenantUsageLimitOverrides] = {}
|
||||
for override_data in tenant_overrides:
|
||||
tenant_id = override_data["tenant_id"]
|
||||
try:
|
||||
result[tenant_id] = TenantUsageLimitOverrides(**override_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
|
||||
Called at server startup to populate the in-memory cache.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
_tenant_usage_limit_overrides = overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
return overrides
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
tenant_id: str,
|
||||
) -> TenantUsageLimitOverrides | None:
|
||||
"""
|
||||
Get the usage limit overrides for a specific tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to look up
|
||||
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
if _tenant_usage_limit_overrides is None:
|
||||
_tenant_usage_limit_overrides = load_usage_limit_overrides()
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
@@ -9,7 +10,10 @@ from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
@@ -21,18 +25,11 @@ from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -40,25 +37,15 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
get_recommendations,
|
||||
)
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
model_configurations_for_provider,
|
||||
)
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import get_openai_model_names
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
@@ -66,7 +53,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
@@ -275,173 +262,61 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
|
||||
|
||||
def _build_model_configuration_upsert_requests(
|
||||
provider_name: str,
|
||||
recommendations: LLMRecommendations,
|
||||
) -> list[ModelConfigurationUpsertRequest]:
|
||||
model_configurations = model_configurations_for_provider(
|
||||
provider_name, recommendations
|
||||
)
|
||||
return [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_configuration.name,
|
||||
is_visible=model_configuration.is_visible,
|
||||
max_input_tokens=model_configuration.max_input_tokens,
|
||||
supports_image_input=model_configuration.supports_image_input,
|
||||
)
|
||||
for model_configuration in model_configurations
|
||||
]
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
"""Configure default LLM providers using recommended-models.json for model selection."""
|
||||
# Load recommendations from JSON config
|
||||
recommendations = get_recommendations()
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
|
||||
# Configure OpenAI provider
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(OPENAI_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {OPENAI_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "gpt-5.2"
|
||||
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
create_default_image_gen_config_from_api_key(
|
||||
db_session, OPENAI_DEFAULT_API_KEY
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create default image gen config: {e}")
|
||||
else:
|
||||
logger.info(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
# Configure Anthropic provider
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(ANTHROPIC_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {ANTHROPIC_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = (
|
||||
default_model.name if default_model else "claude-sonnet-4-5"
|
||||
)
|
||||
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in get_anthropic_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.info(
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
# Configure Vertex AI provider
|
||||
if VERTEXAI_DEFAULT_CREDENTIALS:
|
||||
default_model = recommendations.get_default_model(VERTEXAI_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {VERTEXAI_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "gemini-2.5-pro"
|
||||
|
||||
# Vertex AI uses custom_config for credentials and location
|
||||
custom_config = {
|
||||
VERTEX_CREDENTIALS_FILE_KWARG: VERTEXAI_DEFAULT_CREDENTIALS,
|
||||
VERTEX_LOCATION_KWARG: VERTEXAI_DEFAULT_LOCATION,
|
||||
}
|
||||
|
||||
vertexai_provider = LLMProviderUpsertRequest(
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for model_name in get_openai_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
# Configure OpenRouter provider
|
||||
if OPENROUTER_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(OPENROUTER_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {OPENROUTER_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "z-ai/glm-4.7"
|
||||
|
||||
# For OpenRouter, we use the visible models from recommendations as model_configurations
|
||||
# since OpenRouter models are dynamic (fetched from their API)
|
||||
visible_models = recommendations.get_visible_models(OPENROUTER_PROVIDER_NAME)
|
||||
model_configurations = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model.name,
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
for model in visible_models
|
||||
]
|
||||
|
||||
openrouter_provider = LLMProviderUpsertRequest(
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
)
|
||||
|
||||
# Configure Cohere embedding provider
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
@@ -687,11 +562,17 @@ async def assign_tenant_to_user(
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=email,
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
@@ -249,17 +249,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
)
|
||||
raise
|
||||
|
||||
# Remove from invited users list since they've accepted
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
invited_users = get_invited_users()
|
||||
if email in invited_users:
|
||||
invited_users.remove(email)
|
||||
write_invited_users(invited_users)
|
||||
logger.info(f"Removed {email} from invited users list after acceptance")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -16,9 +16,8 @@ from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
"""
|
||||
Determine if a tenant is currently on a trial subscription.
|
||||
|
||||
In multi-tenant mode, we fetch billing information from the control plane
|
||||
to determine if the tenant has an active trial.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return False
|
||||
|
||||
try:
|
||||
billing_info = fetch_billing_information(tenant_id)
|
||||
|
||||
# If not subscribed at all, check if we have trial information
|
||||
if isinstance(billing_info, SubscriptionStatusResponse):
|
||||
# No subscription means they're likely on trial (new tenant)
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch billing info for trial check: {e}")
|
||||
# Default to trial limits on error (more restrictive = safer)
|
||||
return True
|
||||
@@ -21,12 +21,11 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
"""RSA-4096 license signature verification utilities."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
|
||||
from ee.onyx.server.license.models import LicenseData
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
|
||||
|
||||
def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
"""
|
||||
Verify RSA-4096 signature and return payload if valid.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded JSON containing payload and signature
|
||||
|
||||
Returns:
|
||||
LicensePayload if signature is valid
|
||||
|
||||
Raises:
|
||||
ValueError: If license data is invalid or signature verification fails
|
||||
"""
|
||||
try:
|
||||
# Decode the license data
|
||||
decoded = json.loads(base64.b64decode(license_data))
|
||||
license_obj = LicenseData(**decoded)
|
||||
|
||||
payload_json = json.dumps(
|
||||
license_obj.payload.model_dump(mode="json"), sort_keys=True
|
||||
)
|
||||
signature_bytes = base64.b64decode(license_obj.signature)
|
||||
|
||||
# Verify signature using PSS padding (modern standard)
|
||||
public_key = _get_public_key()
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
payload_json.encode(),
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH,
|
||||
),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
|
||||
return license_obj.payload
|
||||
|
||||
except InvalidSignature:
|
||||
logger.error("License signature verification failed")
|
||||
raise ValueError("Invalid license signature")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode license JSON")
|
||||
raise ValueError("Invalid license format: not valid JSON")
|
||||
except (ValueError, KeyError, TypeError) as e:
|
||||
logger.error(f"License data validation error: {type(e).__name__}")
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during license verification")
|
||||
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
def get_license_status(
|
||||
payload: LicensePayload,
|
||||
grace_period_end: datetime | None = None,
|
||||
) -> ApplicationStatus:
|
||||
"""
|
||||
Determine current license status based on expiry.
|
||||
|
||||
Args:
|
||||
payload: The verified license payload
|
||||
grace_period_end: Optional grace period end datetime
|
||||
|
||||
Returns:
|
||||
ApplicationStatus indicating current license state
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Check if grace period has expired
|
||||
if grace_period_end and now > grace_period_end:
|
||||
return ApplicationStatus.GATED_ACCESS
|
||||
|
||||
# Check if license has expired
|
||||
if now > payload.expires_at:
|
||||
if grace_period_end and now <= grace_period_end:
|
||||
return ApplicationStatus.GRACE_PERIOD
|
||||
return ApplicationStatus.GATED_ACCESS
|
||||
|
||||
# License is valid
|
||||
return ApplicationStatus.ACTIVE
|
||||
|
||||
|
||||
def is_license_valid(payload: LicensePayload) -> bool:
|
||||
"""Check if a license is currently valid (not expired)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now <= payload.expires_at
|
||||
@@ -1,4 +1,5 @@
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
|
||||
562
backend/model_server/custom_models.py
Normal file
562
backend/model_server/custom_models.py
Normal file
@@ -0,0 +1,562 @@
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.onyx_torch_model import ConnectorClassifier
|
||||
from model_server.onyx_torch_model import HybridClassifier
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
|
||||
PreTrainedTokenizer,
|
||||
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
)
|
||||
return _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
|
||||
|
||||
def get_local_connector_classifier(
|
||||
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
|
||||
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
|
||||
) -> ConnectorClassifier:
|
||||
global _CONNECTOR_CLASSIFIER_MODEL
|
||||
if _CONNECTOR_CLASSIFIER_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.info(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
_INTENT_TOKENIZER = cast(
|
||||
PreTrainedTokenizer,
|
||||
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
)
|
||||
return _INTENT_TOKENIZER
|
||||
|
||||
|
||||
def get_local_intent_model(
|
||||
model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
tag: str | None = INTENT_MODEL_TAG,
|
||||
) -> HybridClassifier:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(f"Loading model from local cache: {model_name_or_path}")
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
logger.notice(f"Loaded model from local cache: {local_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> "SetFitModel":
|
||||
from setfit import SetFitModel
|
||||
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load content information model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
|
||||
|
||||
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
|
||||
token and then the user query.
|
||||
"""
|
||||
|
||||
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
|
||||
|
||||
for connector in connectors:
|
||||
connector_token_ids = tokenizer(
|
||||
connector,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
connector_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([connector_token_end_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
query_token_ids = tokenizer(
|
||||
query,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
query_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
|
||||
|
||||
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
def warm_up_connector_classifier_model() -> None:
|
||||
logger.info(
|
||||
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
|
||||
)
|
||||
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
|
||||
connector_classifier = get_local_connector_classifier()
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
["GitHub"],
|
||||
"onyx classifier query google doc",
|
||||
connector_classifier_tokenizer,
|
||||
connector_classifier.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(connector_classifier.device)
|
||||
attention_mask = attention_mask.to(connector_classifier.device)
|
||||
|
||||
connector_classifier(input_ids, attention_mask)
|
||||
|
||||
|
||||
def warm_up_intent_model() -> None:
|
||||
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
|
||||
intent_tokenizer = get_intent_model_tokenizer()
|
||||
tokens = intent_tokenizer(
|
||||
MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
|
||||
)
|
||||
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
intent_model(
|
||||
query_ids=tokens["input_ids"].to(device),
|
||||
query_mask=tokens["attention_mask"].to(device),
|
||||
)
|
||||
|
||||
|
||||
def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
|
||||
outputs = intent_model(
|
||||
query_ids=tokens["input_ids"].to(device),
|
||||
query_mask=tokens["attention_mask"].to(device),
|
||||
)
|
||||
|
||||
token_logits = outputs["token_logits"]
|
||||
intent_logits = outputs["intent_logits"]
|
||||
|
||||
# Move tensors to CPU before applying softmax and converting to numpy
|
||||
intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
|
||||
token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
|
||||
|
||||
# Extract the probabilities for the positive class (index 1) for each token
|
||||
token_positive_probs = token_probabilities[:, 1].tolist()
|
||||
|
||||
return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_content_classification_inference(
|
||||
text_inputs: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
"""
|
||||
Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
boost factor.
|
||||
"""
|
||||
|
||||
def _prob_to_score(prob: float) -> float:
|
||||
"""
|
||||
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
"""
|
||||
_MIN_BASE_SCORE = 0.25
|
||||
_MAX_BASE_SCORE = 0.75
|
||||
if prob < _MIN_BASE_SCORE:
|
||||
raw_score = 0.0
|
||||
elif prob < _MAX_BASE_SCORE:
|
||||
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
else:
|
||||
raw_score = 1.0
|
||||
return (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
+ (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
* raw_score
|
||||
)
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
content_model = get_local_information_content_model()
|
||||
|
||||
# Process inputs in batches
|
||||
all_output_classes: list[int] = []
|
||||
all_base_output_probabilities: list[float] = []
|
||||
|
||||
for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
batch_with_prefix = []
|
||||
batch_indices = []
|
||||
|
||||
# Pre-allocate results for this batch
|
||||
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# Pre-process batch to handle long input exceptions
|
||||
for j, text in enumerate(batch):
|
||||
if len(text) == 0:
|
||||
# if no input, treat as non-informative from the model's perspective
|
||||
batch_output_classes[j] = np.array(0)
|
||||
batch_probabilities[j] = np.array(0.0)
|
||||
logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
elif (
|
||||
len(text.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
# if input is short, use the model
|
||||
batch_with_prefix.append(
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
)
|
||||
batch_indices.append(j)
|
||||
else:
|
||||
# if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
logger.warning("Input for Content Information Model too long")
|
||||
|
||||
if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# Get predictions for the batch
|
||||
model_output_classes = content_model(batch_with_prefix)
|
||||
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# Place results in the correct positions
|
||||
for idx, batch_idx in enumerate(batch_indices):
|
||||
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
1
|
||||
].numpy() # x[1] is prob of the positive class
|
||||
|
||||
all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
logits = [
|
||||
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
for p in all_base_output_probabilities
|
||||
]
|
||||
scaled_logits = [
|
||||
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
for logit in logits
|
||||
]
|
||||
output_probabilities_with_temp = [
|
||||
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
for scaled_logit in scaled_logits
|
||||
]
|
||||
|
||||
prediction_scores = [
|
||||
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
]
|
||||
|
||||
content_classification_predictions = [
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
if not len(tokens) == len(is_keyword):
|
||||
raise ValueError("Length of tokens and keyword predictions must match")
|
||||
|
||||
if input_ids[0] == tokenizer.cls_token_id:
|
||||
tokens = tokens[1:]
|
||||
is_keyword = is_keyword[1:]
|
||||
|
||||
if input_ids[-1] == tokenizer.sep_token_id:
|
||||
tokens = tokens[:-1]
|
||||
is_keyword = is_keyword[:-1]
|
||||
|
||||
unk_token = tokenizer.unk_token
|
||||
if unk_token in tokens:
|
||||
raise ValueError("Unknown token detected in the input")
|
||||
|
||||
keywords = []
|
||||
current_keyword = ""
|
||||
|
||||
for ind, token in enumerate(tokens):
|
||||
if is_keyword[ind]:
|
||||
if token.startswith("##"):
|
||||
current_keyword += token[2:]
|
||||
else:
|
||||
if current_keyword:
|
||||
keywords.append(current_keyword)
|
||||
current_keyword = token
|
||||
else:
|
||||
# If mispredicted a later token of a keyword, add it to the current keyword
|
||||
# to complete it
|
||||
if current_keyword:
|
||||
if len(current_keyword) > 2 and current_keyword.startswith("##"):
|
||||
current_keyword = current_keyword[2:]
|
||||
|
||||
else:
|
||||
keywords.append(current_keyword)
|
||||
current_keyword = ""
|
||||
|
||||
if current_keyword:
|
||||
keywords.append(current_keyword)
|
||||
|
||||
return keywords
|
||||
|
||||
|
||||
def clean_keywords(keywords: list[str]) -> list[str]:
|
||||
cleaned_words = []
|
||||
for word in keywords:
|
||||
word = word[:-2] if word.endswith("'s") else word
|
||||
word = word.replace("/", " ")
|
||||
word = word.replace("'", "").replace('"', "")
|
||||
cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
|
||||
return cleaned_words
|
||||
|
||||
|
||||
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
|
||||
tokenizer = get_connector_classifier_tokenizer()
|
||||
model = get_local_connector_classifier()
|
||||
|
||||
connector_names = req.available_connectors
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
connector_names,
|
||||
req.query,
|
||||
tokenizer,
|
||||
model.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(model.device)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
global_confidence, classifier_confidence = model(input_ids, attention_mask)
|
||||
|
||||
if global_confidence.item() < 0.5:
|
||||
return []
|
||||
|
||||
passed_connectors = []
|
||||
|
||||
for i, connector_name in enumerate(connector_names):
|
||||
if classifier_confidence.view(-1)[i].item() > 0.5:
|
||||
passed_connectors.append(connector_name)
|
||||
|
||||
return passed_connectors
|
||||
|
||||
|
||||
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
model_input = tokenizer(
|
||||
intent_req.query, return_tensors="pt", truncation=False, padding=False
|
||||
)
|
||||
|
||||
if len(model_input.input_ids[0]) > 512:
|
||||
# If the user text is too long, assume it is semantic and keep all words
|
||||
return True, intent_req.query.split()
|
||||
|
||||
intent_probs, token_probs = run_inference(model_input)
|
||||
|
||||
is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
|
||||
|
||||
keyword_preds = [
|
||||
token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
|
||||
]
|
||||
|
||||
try:
|
||||
keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to extract keywords for query: {intent_req.query} due to {e}"
|
||||
)
|
||||
# Fallback to keeping all words
|
||||
keywords = intent_req.query.split()
|
||||
|
||||
cleaned_keywords = clean_keywords(keywords)
|
||||
|
||||
return is_keyword_sequence, cleaned_keywords
|
||||
|
||||
|
||||
@router.post("/connector-classification")
|
||||
async def process_connector_classification_request(
|
||||
classification_request: ConnectorClassificationRequest,
|
||||
) -> ConnectorClassificationResponse:
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError(
|
||||
"Indexing model server should not call connector classification endpoint"
|
||||
)
|
||||
|
||||
if len(classification_request.available_connectors) == 0:
|
||||
return ConnectorClassificationResponse(connectors=[])
|
||||
|
||||
connectors = run_connector_classification(classification_request)
|
||||
return ConnectorClassificationResponse(connectors=connectors)
|
||||
|
||||
|
||||
@router.post("/query-analysis")
|
||||
async def process_analysis_request(
|
||||
intent_request: IntentRequest,
|
||||
) -> IntentResponse:
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
is_keyword, keywords = run_analysis(intent_request)
|
||||
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/content-classification")
|
||||
async def process_content_classification_request(
|
||||
content_classification_requests: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
return run_content_classification_inference(content_classification_requests)
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -9,13 +10,16 @@ from fastapi import Request
|
||||
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -23,6 +27,11 @@ router = APIRouter(prefix="/encoder")
|
||||
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
# If we are not only indexing, dont want retry very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
@@ -33,7 +42,7 @@ def get_embedding_model(
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
@@ -78,6 +87,19 @@ def get_embedding_model(
|
||||
return _GLOBAL_MODELS_DICT[model_name]
|
||||
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
model = CrossEncoder(model_name)
|
||||
_RERANK_MODEL = model
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
ENCODING_RETRIES = 3
|
||||
ENCODING_RETRY_DELAY = 0.1
|
||||
|
||||
@@ -167,6 +189,16 @@ async def embed_text(
|
||||
return embeddings
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model(model_name)
|
||||
# Run CPU-bound reranking in a thread pool
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
@@ -222,3 +254,39 @@ async def process_embed_request(
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if rerank_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server reranking endpoint should only be used for local models. "
|
||||
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not rerank_request.documents or not rerank_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing documents or query for reranking"
|
||||
)
|
||||
if not all(rerank_request.documents):
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
# At this point, provider_type is None, so handle local reranking
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to run Cross-Encoder reranking"
|
||||
)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
This directory contains code that was useful and may become useful again in the future.
|
||||
|
||||
We stopped using rerankers because the state of the art rerankers are not significantly better than the biencoders and much worse than LLMs which are also capable of acting on a small set of documents for filtering, reranking, etc.
|
||||
|
||||
We stopped using the internal query classifier as that's now offloaded to the LLM which does query expansion so we know ahead of time if it's a keyword or semantic query.
|
||||
@@ -1,573 +0,0 @@
|
||||
# from typing import cast
|
||||
# from typing import Optional
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# import numpy as np
|
||||
# import torch
|
||||
# import torch.nn.functional as F
|
||||
# from fastapi import APIRouter
|
||||
# from huggingface_hub import snapshot_download
|
||||
# from pydantic import BaseModel
|
||||
|
||||
# from model_server.constants import MODEL_WARM_UP_STRING
|
||||
# from model_server.legacy.onyx_torch_model import ConnectorClassifier
|
||||
# from model_server.legacy.onyx_torch_model import HybridClassifier
|
||||
# from model_server.utils import simple_log_function_time
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
# from shared_configs.configs import INDEXING_ONLY
|
||||
# from shared_configs.configs import INTENT_MODEL_TAG
|
||||
# from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
# from shared_configs.model_server_models import IntentRequest
|
||||
# from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from setfit import SetFitModel # type: ignore[import-untyped]
|
||||
# from transformers import PreTrainedTokenizer, BatchEncoding
|
||||
|
||||
|
||||
# INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi" * 50
|
||||
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = 1.0
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = 0.7
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = 4.0
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = 10
|
||||
# INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
|
||||
# INFORMATION_CONTENT_MODEL_TAG: str | None = None
|
||||
|
||||
|
||||
# class ConnectorClassificationRequest(BaseModel):
|
||||
# available_connectors: list[str]
|
||||
# query: str
|
||||
|
||||
|
||||
# class ConnectorClassificationResponse(BaseModel):
|
||||
# connectors: list[str]
|
||||
|
||||
|
||||
# class ContentClassificationPrediction(BaseModel):
|
||||
# predicted_label: int
|
||||
# content_boost_factor: float
|
||||
|
||||
|
||||
# logger = setup_logger()
|
||||
|
||||
# router = APIRouter(prefix="/custom")
|
||||
|
||||
# _CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
# _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
# _INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
# _INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
# _INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
# def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
# global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
# from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# # The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# # unmodified distilbert tokenizer.
|
||||
# _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
|
||||
# PreTrainedTokenizer,
|
||||
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
# )
|
||||
# return _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
|
||||
|
||||
# def get_local_connector_classifier(
|
||||
# model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
|
||||
# tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
|
||||
# ) -> ConnectorClassifier:
|
||||
# global _CONNECTOR_CLASSIFIER_MODEL
|
||||
# if _CONNECTOR_CLASSIFIER_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
# local_path
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.info(f"Downloading model snapshot for {model_name_or_path}")
|
||||
# local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
# local_path
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
# return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
# def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
# from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# global _INTENT_TOKENIZER
|
||||
# if _INTENT_TOKENIZER is None:
|
||||
# # The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# # unmodified distilbert tokenizer.
|
||||
# _INTENT_TOKENIZER = cast(
|
||||
# PreTrainedTokenizer,
|
||||
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
# )
|
||||
# return _INTENT_TOKENIZER
|
||||
|
||||
|
||||
# def get_local_intent_model(
|
||||
# model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
# tag: str | None = INTENT_MODEL_TAG,
|
||||
# ) -> HybridClassifier:
|
||||
# global _INTENT_MODEL
|
||||
# if _INTENT_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# logger.notice(f"Loading model from local cache: {model_name_or_path}")
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
# logger.notice(f"Loaded model from local cache: {local_path}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
# )
|
||||
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
# return _INTENT_MODEL
|
||||
|
||||
|
||||
# def get_local_information_content_model(
|
||||
# model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
# tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
# ) -> "SetFitModel":
|
||||
# from setfit import SetFitModel
|
||||
|
||||
# global _INFORMATION_CONTENT_MODEL
|
||||
# if _INFORMATION_CONTENT_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# logger.notice(
|
||||
# f"Loading content information model from local cache: {model_name_or_path}"
|
||||
# )
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
# logger.notice(
|
||||
# f"Loaded content information model from local cache: {local_path}"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load content information model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.notice(
|
||||
# f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
# )
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
# )
|
||||
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
|
||||
# return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
# def tokenize_connector_classification_query(
|
||||
# connectors: list[str],
|
||||
# query: str,
|
||||
# tokenizer: "PreTrainedTokenizer",
|
||||
# connector_token_end_id: int,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# """
|
||||
# Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
|
||||
|
||||
# The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
|
||||
# token and then the user query.
|
||||
# """
|
||||
|
||||
# input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
|
||||
|
||||
# for connector in connectors:
|
||||
# connector_token_ids = tokenizer(
|
||||
# connector,
|
||||
# add_special_tokens=False,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
|
||||
# input_ids = torch.cat(
|
||||
# (
|
||||
# input_ids,
|
||||
# connector_token_ids["input_ids"].squeeze(dim=0),
|
||||
# torch.tensor([connector_token_end_id], dtype=torch.long),
|
||||
# ),
|
||||
# dim=-1,
|
||||
# )
|
||||
# query_token_ids = tokenizer(
|
||||
# query,
|
||||
# add_special_tokens=False,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
|
||||
# input_ids = torch.cat(
|
||||
# (
|
||||
# input_ids,
|
||||
# query_token_ids["input_ids"].squeeze(dim=0),
|
||||
# torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
|
||||
# ),
|
||||
# dim=-1,
|
||||
# )
|
||||
# attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
|
||||
|
||||
# return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
# def warm_up_connector_classifier_model() -> None:
|
||||
# logger.info(
|
||||
# f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
|
||||
# )
|
||||
# connector_classifier_tokenizer = get_connector_classifier_tokenizer()
|
||||
# connector_classifier = get_local_connector_classifier()
|
||||
|
||||
# input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
# ["GitHub"],
|
||||
# "onyx classifier query google doc",
|
||||
# connector_classifier_tokenizer,
|
||||
# connector_classifier.connector_end_token_id,
|
||||
# )
|
||||
# input_ids = input_ids.to(connector_classifier.device)
|
||||
# attention_mask = attention_mask.to(connector_classifier.device)
|
||||
|
||||
# connector_classifier(input_ids, attention_mask)
|
||||
|
||||
|
||||
# def warm_up_intent_model() -> None:
|
||||
# logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
|
||||
# intent_tokenizer = get_intent_model_tokenizer()
|
||||
# tokens = intent_tokenizer(
|
||||
# MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
|
||||
# )
|
||||
|
||||
# intent_model = get_local_intent_model()
|
||||
# device = intent_model.device
|
||||
# intent_model(
|
||||
# query_ids=tokens["input_ids"].to(device),
|
||||
# query_mask=tokens["attention_mask"].to(device),
|
||||
# )
|
||||
|
||||
|
||||
# def warm_up_information_content_model() -> None:
|
||||
# logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
# information_content_model = get_local_information_content_model()
|
||||
# information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
# intent_model = get_local_intent_model()
|
||||
# device = intent_model.device
|
||||
|
||||
# outputs = intent_model(
|
||||
# query_ids=tokens["input_ids"].to(device),
|
||||
# query_mask=tokens["attention_mask"].to(device),
|
||||
# )
|
||||
|
||||
# token_logits = outputs["token_logits"]
|
||||
# intent_logits = outputs["intent_logits"]
|
||||
|
||||
# # Move tensors to CPU before applying softmax and converting to numpy
|
||||
# intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
|
||||
# token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
|
||||
|
||||
# # Extract the probabilities for the positive class (index 1) for each token
|
||||
# token_positive_probs = token_probabilities[:, 1].tolist()
|
||||
|
||||
# return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# def run_content_classification_inference(
|
||||
# text_inputs: list[str],
|
||||
# ) -> list[ContentClassificationPrediction]:
|
||||
# """
|
||||
# Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
# creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
# In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
# boost factor.
|
||||
# """
|
||||
|
||||
# def _prob_to_score(prob: float) -> float:
|
||||
# """
|
||||
# Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
# """
|
||||
# _MIN_BASE_SCORE = 0.25
|
||||
# _MAX_BASE_SCORE = 0.75
|
||||
# if prob < _MIN_BASE_SCORE:
|
||||
# raw_score = 0.0
|
||||
# elif prob < _MAX_BASE_SCORE:
|
||||
# raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
# else:
|
||||
# raw_score = 1.0
|
||||
# return (
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
# + (
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
# - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
# )
|
||||
# * raw_score
|
||||
# )
|
||||
|
||||
# _BATCH_SIZE = 32
|
||||
# content_model = get_local_information_content_model()
|
||||
|
||||
# # Process inputs in batches
|
||||
# all_output_classes: list[int] = []
|
||||
# all_base_output_probabilities: list[float] = []
|
||||
|
||||
# for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
# batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
# batch_with_prefix = []
|
||||
# batch_indices = []
|
||||
|
||||
# # Pre-allocate results for this batch
|
||||
# batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
# batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# # Pre-process batch to handle long input exceptions
|
||||
# for j, text in enumerate(batch):
|
||||
# if len(text) == 0:
|
||||
# # if no input, treat as non-informative from the model's perspective
|
||||
# batch_output_classes[j] = np.array(0)
|
||||
# batch_probabilities[j] = np.array(0.0)
|
||||
# logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
# elif (
|
||||
# len(text.split())
|
||||
# <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
# ):
|
||||
# # if input is short, use the model
|
||||
# batch_with_prefix.append(
|
||||
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
# )
|
||||
# batch_indices.append(j)
|
||||
# else:
|
||||
# # if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
# logger.warning("Input for Content Information Model too long")
|
||||
|
||||
# if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# # Get predictions for the batch
|
||||
# model_output_classes = content_model(batch_with_prefix)
|
||||
# model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# # Place results in the correct positions
|
||||
# for idx, batch_idx in enumerate(batch_indices):
|
||||
# batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
# batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
# 1
|
||||
# ].numpy() # x[1] is prob of the positive class
|
||||
|
||||
# all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
# all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
# logits = [
|
||||
# np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
# for p in all_base_output_probabilities
|
||||
# ]
|
||||
# scaled_logits = [
|
||||
# logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
# for logit in logits
|
||||
# ]
|
||||
# output_probabilities_with_temp = [
|
||||
# np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
# for scaled_logit in scaled_logits
|
||||
# ]
|
||||
|
||||
# prediction_scores = [
|
||||
# _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
# ]
|
||||
|
||||
# content_classification_predictions = [
|
||||
# ContentClassificationPrediction(
|
||||
# predicted_label=predicted_label, content_boost_factor=output_score
|
||||
# )
|
||||
# for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
# ]
|
||||
|
||||
# return content_classification_predictions
|
||||
|
||||
|
||||
# def map_keywords(
|
||||
# input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
# ) -> list[str]:
|
||||
# tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
# if not len(tokens) == len(is_keyword):
|
||||
# raise ValueError("Length of tokens and keyword predictions must match")
|
||||
|
||||
# if input_ids[0] == tokenizer.cls_token_id:
|
||||
# tokens = tokens[1:]
|
||||
# is_keyword = is_keyword[1:]
|
||||
|
||||
# if input_ids[-1] == tokenizer.sep_token_id:
|
||||
# tokens = tokens[:-1]
|
||||
# is_keyword = is_keyword[:-1]
|
||||
|
||||
# unk_token = tokenizer.unk_token
|
||||
# if unk_token in tokens:
|
||||
# raise ValueError("Unknown token detected in the input")
|
||||
|
||||
# keywords = []
|
||||
# current_keyword = ""
|
||||
|
||||
# for ind, token in enumerate(tokens):
|
||||
# if is_keyword[ind]:
|
||||
# if token.startswith("##"):
|
||||
# current_keyword += token[2:]
|
||||
# else:
|
||||
# if current_keyword:
|
||||
# keywords.append(current_keyword)
|
||||
# current_keyword = token
|
||||
# else:
|
||||
# # If mispredicted a later token of a keyword, add it to the current keyword
|
||||
# # to complete it
|
||||
# if current_keyword:
|
||||
# if len(current_keyword) > 2 and current_keyword.startswith("##"):
|
||||
# current_keyword = current_keyword[2:]
|
||||
|
||||
# else:
|
||||
# keywords.append(current_keyword)
|
||||
# current_keyword = ""
|
||||
|
||||
# if current_keyword:
|
||||
# keywords.append(current_keyword)
|
||||
|
||||
# return keywords
|
||||
|
||||
|
||||
# def clean_keywords(keywords: list[str]) -> list[str]:
|
||||
# cleaned_words = []
|
||||
# for word in keywords:
|
||||
# word = word[:-2] if word.endswith("'s") else word
|
||||
# word = word.replace("/", " ")
|
||||
# word = word.replace("'", "").replace('"', "")
|
||||
# cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
|
||||
# return cleaned_words
|
||||
|
||||
|
||||
# def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
|
||||
# tokenizer = get_connector_classifier_tokenizer()
|
||||
# model = get_local_connector_classifier()
|
||||
|
||||
# connector_names = req.available_connectors
|
||||
|
||||
# input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
# connector_names,
|
||||
# req.query,
|
||||
# tokenizer,
|
||||
# model.connector_end_token_id,
|
||||
# )
|
||||
# input_ids = input_ids.to(model.device)
|
||||
# attention_mask = attention_mask.to(model.device)
|
||||
|
||||
# global_confidence, classifier_confidence = model(input_ids, attention_mask)
|
||||
|
||||
# if global_confidence.item() < 0.5:
|
||||
# return []
|
||||
|
||||
# passed_connectors = []
|
||||
|
||||
# for i, connector_name in enumerate(connector_names):
|
||||
# if classifier_confidence.view(-1)[i].item() > 0.5:
|
||||
# passed_connectors.append(connector_name)
|
||||
|
||||
# return passed_connectors
|
||||
|
||||
|
||||
# def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
# tokenizer = get_intent_model_tokenizer()
|
||||
# model_input = tokenizer(
|
||||
# intent_req.query, return_tensors="pt", truncation=False, padding=False
|
||||
# )
|
||||
|
||||
# if len(model_input.input_ids[0]) > 512:
|
||||
# # If the user text is too long, assume it is semantic and keep all words
|
||||
# return True, intent_req.query.split()
|
||||
|
||||
# intent_probs, token_probs = run_inference(model_input)
|
||||
|
||||
# is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
|
||||
|
||||
# keyword_preds = [
|
||||
# token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
|
||||
# ]
|
||||
|
||||
# try:
|
||||
# keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
|
||||
# except Exception as e:
|
||||
# logger.warning(
|
||||
# f"Failed to extract keywords for query: {intent_req.query} due to {e}"
|
||||
# )
|
||||
# # Fallback to keeping all words
|
||||
# keywords = intent_req.query.split()
|
||||
|
||||
# cleaned_keywords = clean_keywords(keywords)
|
||||
|
||||
# return is_keyword_sequence, cleaned_keywords
|
||||
|
||||
|
||||
# @router.post("/connector-classification")
|
||||
# async def process_connector_classification_request(
|
||||
# classification_request: ConnectorClassificationRequest,
|
||||
# ) -> ConnectorClassificationResponse:
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError(
|
||||
# "Indexing model server should not call connector classification endpoint"
|
||||
# )
|
||||
|
||||
# if len(classification_request.available_connectors) == 0:
|
||||
# return ConnectorClassificationResponse(connectors=[])
|
||||
|
||||
# connectors = run_connector_classification(classification_request)
|
||||
# return ConnectorClassificationResponse(connectors=connectors)
|
||||
|
||||
|
||||
# @router.post("/query-analysis")
|
||||
# async def process_analysis_request(
|
||||
# intent_request: IntentRequest,
|
||||
# ) -> IntentResponse:
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
# is_keyword, keywords = run_analysis(intent_request)
|
||||
# return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
# @router.post("/content-classification")
|
||||
# async def process_content_classification_request(
|
||||
# content_classification_requests: list[str],
|
||||
# ) -> list[ContentClassificationPrediction]:
|
||||
# return run_content_classification_inference(content_classification_requests)
|
||||
@@ -1,154 +0,0 @@
|
||||
# import json
|
||||
# import os
|
||||
# from typing import cast
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# import torch
|
||||
# import torch.nn as nn
|
||||
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from transformers import DistilBertConfig
|
||||
|
||||
|
||||
# class HybridClassifier(nn.Module):
|
||||
# def __init__(self) -> None:
|
||||
# from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
# super().__init__()
|
||||
# config = DistilBertConfig()
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
|
||||
# # Keyword tokenwise binary classification layer
|
||||
# self.keyword_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# # Intent Classifier layers
|
||||
# self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||
# self.intent_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# self.device = torch.device("cpu")
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# query_ids: torch.Tensor,
|
||||
# query_mask: torch.Tensor,
|
||||
# ) -> dict[str, torch.Tensor]:
|
||||
# outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
|
||||
# sequence_output = outputs.last_hidden_state
|
||||
|
||||
# # Intent classification on the CLS token
|
||||
# cls_token_state = sequence_output[:, 0, :]
|
||||
# pre_classifier_out = self.pre_classifier(cls_token_state)
|
||||
# intent_logits = self.intent_classifier(pre_classifier_out)
|
||||
|
||||
# # Keyword classification on all tokens
|
||||
# token_logits = self.keyword_classifier(sequence_output)
|
||||
|
||||
# return {"intent_logits": intent_logits, "token_logits": token_logits}
|
||||
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
|
||||
# model_path = os.path.join(load_directory, "pytorch_model.bin")
|
||||
# config_path = os.path.join(load_directory, "config.json")
|
||||
|
||||
# with open(config_path, "r") as f:
|
||||
# config = json.load(f)
|
||||
# model = cls(**config)
|
||||
|
||||
# if torch.backends.mps.is_available():
|
||||
# # Apple silicon GPU
|
||||
# device = torch.device("mps")
|
||||
# elif torch.cuda.is_available():
|
||||
# device = torch.device("cuda")
|
||||
# else:
|
||||
# device = torch.device("cpu")
|
||||
|
||||
# model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
# model = model.to(device)
|
||||
|
||||
# model.device = device
|
||||
|
||||
# model.eval()
|
||||
# # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
|
||||
# for param in model.parameters():
|
||||
# param.requires_grad = False
|
||||
|
||||
# return model
|
||||
|
||||
|
||||
# class ConnectorClassifier(nn.Module):
|
||||
# def __init__(self, config: "DistilBertConfig") -> None:
|
||||
# from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
# super().__init__()
|
||||
|
||||
# self.config = config
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
# self.connector_global_classifier = nn.Linear(config.dim, 1)
|
||||
# self.connector_match_classifier = nn.Linear(config.dim, 1)
|
||||
# self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
# # Token indicating end of connector name, and on which classifier is used
|
||||
# self.connector_end_token_id = self.tokenizer.get_vocab()[
|
||||
# self.config.connector_end_token
|
||||
# ]
|
||||
|
||||
# self.device = torch.device("cpu")
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# input_ids: torch.Tensor,
|
||||
# attention_mask: torch.Tensor,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states = self.distilbert(
|
||||
# input_ids=input_ids, attention_mask=attention_mask
|
||||
# ).last_hidden_state
|
||||
|
||||
# cls_hidden_states = hidden_states[
|
||||
# :, 0, :
|
||||
# ] # Take leap of faith that first token is always [CLS]
|
||||
# global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
|
||||
# global_confidence = torch.sigmoid(global_logits).view(-1)
|
||||
|
||||
# connector_end_position_ids = input_ids == self.connector_end_token_id
|
||||
# connector_end_hidden_states = hidden_states[connector_end_position_ids]
|
||||
# classifier_output = self.connector_match_classifier(connector_end_hidden_states)
|
||||
# classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
|
||||
|
||||
# return global_confidence, classifier_confidence
|
||||
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
# from transformers import DistilBertConfig
|
||||
|
||||
# config = cast(
|
||||
# DistilBertConfig,
|
||||
# DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
# )
|
||||
# device = (
|
||||
# torch.device("cuda")
|
||||
# if torch.cuda.is_available()
|
||||
# else (
|
||||
# torch.device("mps")
|
||||
# if torch.backends.mps.is_available()
|
||||
# else torch.device("cpu")
|
||||
# )
|
||||
# )
|
||||
# state_dict = torch.load(
|
||||
# os.path.join(repo_dir, "pytorch_model.pt"),
|
||||
# map_location=device,
|
||||
# weights_only=True,
|
||||
# )
|
||||
|
||||
# model = cls(config)
|
||||
# model.load_state_dict(state_dict)
|
||||
# model.to(device)
|
||||
# model.device = device
|
||||
# model.eval()
|
||||
|
||||
# for param in model.parameters():
|
||||
# param.requires_grad = False
|
||||
|
||||
# return model
|
||||
@@ -1,80 +0,0 @@
|
||||
# import asyncio
|
||||
# from typing import Optional
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# from fastapi import APIRouter
|
||||
# from fastapi import HTTPException
|
||||
|
||||
# from model_server.utils import simple_log_function_time
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from shared_configs.configs import INDEXING_ONLY
|
||||
# from shared_configs.model_server_models import RerankRequest
|
||||
# from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from sentence_transformers import CrossEncoder
|
||||
|
||||
# logger = setup_logger()
|
||||
|
||||
# router = APIRouter(prefix="/encoder")
|
||||
|
||||
# _RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
|
||||
# def get_local_reranking_model(
|
||||
# model_name: str,
|
||||
# ) -> "CrossEncoder":
|
||||
# global _RERANK_MODEL
|
||||
# from sentence_transformers import CrossEncoder
|
||||
|
||||
# if _RERANK_MODEL is None:
|
||||
# logger.notice(f"Loading {model_name}")
|
||||
# model = CrossEncoder(model_name)
|
||||
# _RERANK_MODEL = model
|
||||
# return _RERANK_MODEL
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
# cross_encoder = get_local_reranking_model(model_name)
|
||||
# # Run CPU-bound reranking in a thread pool
|
||||
# return await asyncio.get_event_loop().run_in_executor(
|
||||
# None,
|
||||
# lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
|
||||
# )
|
||||
|
||||
|
||||
# @router.post("/cross-encoder-scores")
|
||||
# async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
# """Cross encoders can be purely black box from the app perspective"""
|
||||
# # Only local models should use this endpoint - API providers should make direct API calls
|
||||
# if rerank_request.provider_type is not None:
|
||||
# raise ValueError(
|
||||
# f"Model server reranking endpoint should only be used for local models. "
|
||||
# f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
# )
|
||||
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError("Indexing model server should not call reranking endpoint")
|
||||
|
||||
# if not rerank_request.documents or not rerank_request.query:
|
||||
# raise HTTPException(
|
||||
# status_code=400, detail="Missing documents or query for reranking"
|
||||
# )
|
||||
# if not all(rerank_request.documents):
|
||||
# raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
# try:
|
||||
# # At this point, provider_type is None, so handle local reranking
|
||||
# sim_scores = await local_rerank(
|
||||
# query=rerank_request.query,
|
||||
# docs=rerank_request.documents,
|
||||
# model_name=rerank_request.model_name,
|
||||
# )
|
||||
# return RerankResponse(scores=sim_scores)
|
||||
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
# raise HTTPException(
|
||||
# status_code=500, detail="Failed to run Cross-Encoder reranking"
|
||||
# )
|
||||
@@ -12,8 +12,11 @@ from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from model_server.utils import get_gpu_type
|
||||
@@ -27,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -88,6 +92,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not SKIP_WARM_UP:
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice("Warming up intent model for inference model server")
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"Warming up content information model for indexing model server"
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
else:
|
||||
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -107,6 +123,7 @@ def get_model_app() -> FastAPI:
|
||||
|
||||
application.include_router(management_router)
|
||||
application.include_router(encoders_router)
|
||||
application.include_router(custom_models_router)
|
||||
|
||||
request_id_prefix = "INF"
|
||||
if INDEXING_ONLY:
|
||||
|
||||
154
backend/model_server/onyx_torch_model.py
Normal file
154
backend/model_server/onyx_torch_model.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
config = DistilBertConfig()
|
||||
self.distilbert = DistilBertModel(config)
|
||||
config = self.distilbert.config # type: ignore
|
||||
|
||||
# Keyword tokenwise binary classification layer
|
||||
self.keyword_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# Intent Classifier layers
|
||||
self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||
self.intent_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_ids: torch.Tensor,
|
||||
query_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
# Intent classification on the CLS token
|
||||
cls_token_state = sequence_output[:, 0, :]
|
||||
pre_classifier_out = self.pre_classifier(cls_token_state)
|
||||
intent_logits = self.intent_classifier(pre_classifier_out)
|
||||
|
||||
# Keyword classification on all tokens
|
||||
token_logits = self.keyword_classifier(sequence_output)
|
||||
|
||||
return {"intent_logits": intent_logits, "token_logits": token_logits}
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
|
||||
model_path = os.path.join(load_directory, "pytorch_model.bin")
|
||||
config_path = os.path.join(load_directory, "config.json")
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
model = cls(**config)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# Apple silicon GPU
|
||||
device = torch.device("mps")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
model = model.to(device)
|
||||
|
||||
model.device = device
|
||||
|
||||
model.eval()
|
||||
# Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: "DistilBertConfig") -> None:
|
||||
from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.distilbert = DistilBertModel(config)
|
||||
config = self.distilbert.config # type: ignore
|
||||
self.connector_global_classifier = nn.Linear(config.dim, 1)
|
||||
self.connector_match_classifier = nn.Linear(config.dim, 1)
|
||||
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
# Token indicating end of connector name, and on which classifier is used
|
||||
self.connector_end_token_id = self.tokenizer.get_vocab()[
|
||||
self.config.connector_end_token
|
||||
]
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.distilbert( # type: ignore
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
||||
cls_hidden_states = hidden_states[
|
||||
:, 0, :
|
||||
] # Take leap of faith that first token is always [CLS]
|
||||
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
|
||||
global_confidence = torch.sigmoid(global_logits).view(-1)
|
||||
|
||||
connector_end_position_ids = input_ids == self.connector_end_token_id
|
||||
connector_end_hidden_states = hidden_states[connector_end_position_ids]
|
||||
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
|
||||
classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
|
||||
|
||||
return global_confidence, classifier_confidence
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
config = cast(
|
||||
DistilBertConfig,
|
||||
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
)
|
||||
device = (
|
||||
torch.device("cuda")
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
torch.device("mps")
|
||||
if torch.backends.mps.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
)
|
||||
state_dict = torch.load(
|
||||
os.path.join(repo_dir, "pytorch_model.pt"),
|
||||
map_location=device,
|
||||
weights_only=True,
|
||||
)
|
||||
|
||||
model = cls(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(device)
|
||||
model.device = device
|
||||
model.eval()
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
@@ -43,7 +43,7 @@ def get_access_for_document(
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
@@ -93,7 +93,9 @@ def get_access_for_documents(
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
@@ -111,7 +113,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_acl_for_user"
|
||||
)
|
||||
return versioned_acl_for_user_fn(user, db_session)
|
||||
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||
|
||||
|
||||
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Captcha verification for user registration."""
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import CAPTCHA_ENABLED
|
||||
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
|
||||
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"
|
||||
|
||||
|
||||
class CaptchaVerificationError(Exception):
|
||||
"""Raised when captcha verification fails."""
|
||||
|
||||
|
||||
class RecaptchaResponse(BaseModel):
|
||||
"""Response from Google reCAPTCHA verification API."""
|
||||
|
||||
success: bool
|
||||
score: float | None = None # Only present for reCAPTCHA v3
|
||||
action: str | None = None
|
||||
challenge_ts: str | None = None
|
||||
hostname: str | None = None
|
||||
error_codes: list[str] | None = Field(default=None, alias="error-codes")
|
||||
|
||||
|
||||
def is_captcha_enabled() -> bool:
|
||||
"""Check if captcha verification is enabled."""
|
||||
return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)
|
||||
|
||||
|
||||
async def verify_captcha_token(
|
||||
token: str,
|
||||
expected_action: str = "signup",
|
||||
) -> None:
|
||||
"""
|
||||
Verify a reCAPTCHA token with Google's API.
|
||||
|
||||
Args:
|
||||
token: The reCAPTCHA response token from the client
|
||||
expected_action: Expected action name for v3 verification
|
||||
|
||||
Raises:
|
||||
CaptchaVerificationError: If verification fails
|
||||
"""
|
||||
if not is_captcha_enabled():
|
||||
return
|
||||
|
||||
if not token:
|
||||
raise CaptchaVerificationError("Captcha token is required")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
RECAPTCHA_VERIFY_URL,
|
||||
data={
|
||||
"secret": RECAPTCHA_SECRET_KEY,
|
||||
"response": token,
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
result = RecaptchaResponse(**data)
|
||||
|
||||
if not result.success:
|
||||
error_codes = result.error_codes or ["unknown-error"]
|
||||
logger.warning(f"Captcha verification failed: {error_codes}")
|
||||
raise CaptchaVerificationError(
|
||||
f"Captcha verification failed: {', '.join(error_codes)}"
|
||||
)
|
||||
|
||||
# For reCAPTCHA v3, also check the score
|
||||
if result.score is not None:
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
|
||||
# Optionally verify the action matches
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Captcha verification passed: score={result.score}, "
|
||||
f"action={result.action}"
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Captcha API request failed: {e}")
|
||||
# In case of API errors, we might want to allow registration
|
||||
# to prevent blocking legitimate users. This is a policy decision.
|
||||
raise CaptchaVerificationError("Captcha verification service unavailable")
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Utility to validate and block disposable/temporary email addresses.
|
||||
|
||||
This module fetches a list of known disposable email domains from a remote source
|
||||
and caches them for performance. It's used during user registration to prevent
|
||||
abuse from temporary email services.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Set
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DisposableEmailValidator:
|
||||
"""
|
||||
Thread-safe singleton validator for disposable email domains.
|
||||
|
||||
Fetches and caches the list of disposable domains, with periodic refresh.
|
||||
"""
|
||||
|
||||
_instance: "DisposableEmailValidator | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "DisposableEmailValidator":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Check if already initialized using a try/except to avoid type issues
|
||||
try:
|
||||
if self._initialized:
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self._domains: Set[str] = set()
|
||||
self._last_fetch_time: float = 0
|
||||
self._fetch_lock = threading.Lock()
|
||||
# Cache for 1 hour
|
||||
self._cache_duration = 3600
|
||||
# Hardcoded fallback list of common disposable domains
|
||||
# This ensures we block at least these even if the remote fetch fails
|
||||
self._fallback_domains = {
|
||||
"trashlify.com",
|
||||
"10minutemail.com",
|
||||
"guerrillamail.com",
|
||||
"mailinator.com",
|
||||
"tempmail.com",
|
||||
"throwaway.email",
|
||||
"yopmail.com",
|
||||
"temp-mail.org",
|
||||
"getnada.com",
|
||||
"maildrop.cc",
|
||||
}
|
||||
# Set initialized flag last to prevent race conditions
|
||||
self._initialized: bool = True
|
||||
|
||||
def _should_refresh(self) -> bool:
|
||||
"""Check if the cached domains should be refreshed."""
|
||||
return (time.time() - self._last_fetch_time) > self._cache_duration
|
||||
|
||||
def _fetch_domains(self) -> Set[str]:
|
||||
"""
|
||||
Fetch disposable email domains from the configured URL.
|
||||
|
||||
Returns:
|
||||
Set of domain strings (lowercased)
|
||||
"""
|
||||
if not DISPOSABLE_EMAIL_DOMAINS_URL:
|
||||
logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
|
||||
)
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
domains_list = response.json()
|
||||
|
||||
if not isinstance(domains_list, list):
|
||||
logger.error(
|
||||
f"Expected list from disposable domains URL, got {type(domains_list)}"
|
||||
)
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
# Convert all to lowercase and create set
|
||||
domains = {domain.lower().strip() for domain in domains_list if domain}
|
||||
|
||||
# Always include fallback domains
|
||||
domains.update(self._fallback_domains)
|
||||
|
||||
logger.info(
|
||||
f"Successfully fetched {len(domains)} disposable email domains"
|
||||
)
|
||||
return domains
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch disposable domains: {e}")
|
||||
|
||||
# On error, return fallback domains
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
def get_domains(self) -> Set[str]:
|
||||
"""
|
||||
Get the cached set of disposable email domains.
|
||||
Refreshes the cache if needed.
|
||||
|
||||
Returns:
|
||||
Set of disposable domain strings (lowercased)
|
||||
"""
|
||||
# Fast path: return cached domains if still fresh
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
# Slow path: need to refresh
|
||||
with self._fetch_lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
self._domains = self._fetch_domains()
|
||||
self._last_fetch_time = time.time()
|
||||
return self._domains.copy()
|
||||
|
||||
def is_disposable(self, email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable domain.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email domain is disposable, False otherwise
|
||||
"""
|
||||
if not email or "@" not in email:
|
||||
return False
|
||||
|
||||
parts = email.split("@")
|
||||
if len(parts) != 2 or not parts[0]: # Must have user@domain with non-empty user
|
||||
return False
|
||||
|
||||
domain = parts[1].lower().strip()
|
||||
if not domain: # Domain part must not be empty
|
||||
return False
|
||||
|
||||
disposable_domains = self.get_domains()
|
||||
return domain in disposable_domains
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_validator = DisposableEmailValidator()
|
||||
|
||||
|
||||
def is_disposable_email(email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable/temporary domain.
|
||||
|
||||
This is a convenience function that uses the global validator instance.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email uses a disposable domain, False otherwise
|
||||
"""
|
||||
return _validator.is_disposable(email)
|
||||
|
||||
|
||||
def refresh_disposable_domains() -> None:
|
||||
"""
|
||||
Force a refresh of the disposable domains list.
|
||||
|
||||
This can be called manually if you want to update the list
|
||||
without waiting for the cache to expire.
|
||||
"""
|
||||
_validator._last_fetch_time = 0
|
||||
_validator.get_domains()
|
||||
@@ -40,8 +40,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
captcha_token: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user