Compare commits

..

2 Commits

Author SHA1 Message Date
Weves
ca3db17b08 add restart 2025-12-17 12:48:46 -08:00
Weves
ffd13b1104 dump scripts 2025-12-17 12:48:46 -08:00
1194 changed files with 36500 additions and 91759 deletions

View File

@@ -1,8 +0,0 @@
# Exclude these commits from git blame (e.g. mass reformatting).
# These are ignored by GitHub automatically.
# To enable this locally, run:
#
# git config blame.ignoreRevsFile .git-blame-ignore-revs
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190

7
.github/CODEOWNERS vendored
View File

@@ -1,10 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara
# Web standards updates
/web/STANDARDS.md @raunakab @Weves
# Agent context files
/CLAUDE.md.template @Weves
/AGENTS.md.template @Weves

View File

@@ -7,6 +7,12 @@ inputs:
runs:
using: "composite"
steps:
- name: Setup uv
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
# TODO: Enable caching once there is a uv.lock file checked in.
# with:
# enable-cache: true
- name: Compute requirements hash
id: req-hash
shell: bash
@@ -22,8 +28,6 @@ runs:
done <<< "$REQUIREMENTS"
echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
# NOTE: This comes before Setup uv since clean-ups run in reverse chronological order
# such that Setup uv's prune-cache is able to prune the cache before we upload.
- name: Cache uv cache directory
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
@@ -32,14 +36,6 @@ runs:
restore-keys: |
${{ runner.os }}-uv-
- name: Setup uv
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# TODO: Enable caching once there is a uv.lock file checked in.
# with:
# enable-cache: true
- name: Setup Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
with:

View File

@@ -1,10 +1,10 @@
## Description
<!--- Provide a brief description of the changes in this PR --->
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
<!--- Describe the tests you ran to verify your changes --->
[Describe the tests you ran to verify your changes]
## Additional Options

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +0,0 @@
name: Merge Group-Specific
on:
merge_group:
permissions:
contents: read
jobs:
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
# There is a similarly named "required" job in pr-integration-tests.yml which runs the actual
# integration tests. That job runs on both pull_request and merge_group events, and this job
# exists solely to provide a fast-passing check with the same name for branch protection.
# The actual tests remain enforced on presubmit (pull_request events).
required:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Success
run: echo "Success"
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
# There is a similarly named "playwright-required" job in pr-playwright-tests.yml which runs
# the actual playwright tests. That job runs on both pull_request and merge_group events, and
# this job exists solely to provide a fast-passing check with the same name for branch protection.
# The actual tests remain enforced on presubmit (pull_request events).
playwright-required:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Success
run: echo "Success"

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -1,62 +0,0 @@
name: Database Tests
concurrency:
group: Database-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read
jobs:
database-tests:
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- "run-id=${{ github.run_id }}-database-tests"
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Generate OpenAPI schema and Python client
shell: bash
run: |
ods openapi all
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers
working-directory: ./deployment/docker_compose
run: |
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d \
relational_db
- name: Run Database Tests
working-directory: ./backend
run: pytest -m alembic tests/integration/tests/migrations/

View File

@@ -38,8 +38,6 @@ env:
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}
# Code Interpreter
# TODO: debug why this is failing and enable
@@ -172,7 +170,7 @@ jobs:
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/

View File

@@ -6,11 +6,11 @@ concurrency:
on:
merge_group:
pull_request:
branches: [main]
branches: [ main ]
push:
tags:
- "v*.*.*"
workflow_dispatch: # Allows manual triggering
workflow_dispatch: # Allows manual triggering
permissions:
contents: read
@@ -18,233 +18,225 @@ permissions:
jobs:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=8cpu-linux-x64,
hdd=256,
"run-id=${{ github.run_id }}-helm-chart-check",
]
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
timeout-minutes: 45
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Set up Helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
with:
version: v3.19.0
- name: Set up Helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
with:
version: v3.19.0
- name: Set up chart-testing
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
with:
uv_version: "0.9.9"
- name: Set up chart-testing
uses: helm/chart-testing-action@6ec842c01de15ebb84c8627d2744a0c2f2755c9f # ratchet:helm/chart-testing-action@v2.8.0
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
echo "default_branch: ${DEFAULT_BRANCH}"
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
echo "default_branch: ${DEFAULT_BRANCH}"
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# uncomment to force run chart-testing
# - name: Force run chart-testing (list-changed)
# id: list-changed
# run: echo "changed=true" >> $GITHUB_OUTPUT
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
# uncomment to force run chart-testing
# - name: Force run chart-testing (list-changed)
# id: list-changed
# run: echo "changed=true" >> $GITHUB_OUTPUT
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Pre-install cluster status check
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-install Cluster Status ==="
kubectl get nodes -o wide
kubectl get pods --all-namespaces
kubectl get storageclass
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
- name: Add Helm repositories and update
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Pre-install cluster status check
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-install Cluster Status ==="
kubectl get nodes -o wide
kubectl get pods --all-namespaces
kubectl get storageclass
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
- name: Add Helm repositories and update
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Pre-pull required images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling required images to avoid timeout ==="
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
- name: Pre-pull required images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling required images to avoid timeout ==="
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "Failed to pull $image"
fi
done
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
echo "=== Images loaded into Kind cluster ==="
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Validating chart dependencies ==="
cd deployment/helm/charts/onyx
helm dependency update
helm lint .
- name: Run chart-testing (install) with enhanced monitoring
timeout-minutes: 25
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Starting chart installation with monitoring ==="
# Function to monitor cluster state
monitor_cluster() {
while true; do
echo "=== Cluster Status Check at $(date) ==="
# Only show non-running pods to reduce noise
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
echo "Non-running pods:"
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
else
echo "All pods running successfully"
fi
# Only show recent events if there are issues
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
if [ -n "$RECENT_EVENTS" ]; then
echo "Recent warnings/errors:"
echo "$RECENT_EVENTS"
fi
sleep 60
done
}
# Start monitoring in background
monitor_cluster &
MONITOR_PID=$!
# Set up cleanup
cleanup() {
echo "=== Cleaning up monitoring process ==="
kill $MONITOR_PID 2>/dev/null || true
echo "=== Final cluster state ==="
kubectl get pods --all-namespaces
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
}
# Trap cleanup on exit
trap cleanup EXIT
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docfetching.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_file_processing.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
CT_EXIT=$?
set -e
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "=== Installation completed successfully ==="
echo "Failed to pull $image"
fi
done
kubectl get pods --all-namespaces
echo "=== Images loaded into Kind cluster ==="
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
- name: Post-install verification
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Post-install verification ==="
kubectl get pods --all-namespaces
kubectl get services --all-namespaces
# Only show issues if they exist
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Validating chart dependencies ==="
cd deployment/helm/charts/onyx
helm dependency update
helm lint .
- name: Cleanup on failure
if: failure() && steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Cleanup on failure ==="
- name: Run chart-testing (install) with enhanced monitoring
timeout-minutes: 25
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Starting chart installation with monitoring ==="
# Function to monitor cluster state
monitor_cluster() {
while true; do
echo "=== Cluster Status Check at $(date) ==="
# Only show non-running pods to reduce noise
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
echo "Non-running pods:"
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
else
echo "All pods running successfully"
fi
# Only show recent events if there are issues
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
if [ -n "$RECENT_EVENTS" ]; then
echo "Recent warnings/errors:"
echo "$RECENT_EVENTS"
fi
sleep 60
done
}
# Start monitoring in background
monitor_cluster &
MONITOR_PID=$!
# Set up cleanup
cleanup() {
echo "=== Cleaning up monitoring process ==="
kill $MONITOR_PID 2>/dev/null || true
echo "=== Final cluster state ==="
kubectl get pods --all-namespaces
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
}
echo "=== Pod descriptions for debugging ==="
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
# Trap cleanup on exit
trap cleanup EXIT
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docfetching.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_file_processing.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
CT_EXIT=$?
set -e
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
else
echo "=== Installation completed successfully ==="
fi
kubectl get pods --all-namespaces
- name: Post-install verification
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Post-install verification ==="
kubectl get pods --all-namespaces
kubectl get services --all-namespaces
# Only show issues if they exist
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
- name: Cleanup on failure
if: failure() && steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Cleanup on failure ==="
echo "=== Final cluster state ==="
kubectl get pods --all-namespaces
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
echo "=== Pod descriptions for debugging ==="
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -33,11 +33,6 @@ env:
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
jobs:
discover-test-dirs:
@@ -56,7 +51,7 @@ jobs:
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
@@ -72,14 +67,9 @@ jobs:
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -132,14 +122,9 @@ jobs:
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -191,14 +176,9 @@ jobs:
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -240,7 +220,7 @@ jobs:
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
@@ -310,9 +290,7 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
- name: Start Docker containers
@@ -326,6 +304,7 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker
@@ -368,6 +347,12 @@ jobs:
}
wait_for_service "http://localhost:8080/health" "API server"
test_dir="${{ matrix.test-dir.path }}"
if [ "$test_dir" = "tests/mcp" ]; then
wait_for_service "http://localhost:8090/health" "MCP server"
else
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
fi
echo "Finished waiting for services."
- name: Start Mock Services
@@ -397,6 +382,8 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
@@ -412,11 +399,6 @@ jobs:
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
@@ -439,22 +421,21 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
multitenant-tests:
needs:
[build-backend-image, build-model-server-image, build-integration-image]
runs-on:
[
runs-on,
runner=8cpu-linux-arm64,
"run-id=${{ github.run_id }}-multitenant-tests",
"extras=ecr-cache",
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
timeout-minutes: 45
steps:
@@ -481,10 +462,10 @@ jobs:
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
DEV_MODE=true \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
index \
@@ -493,6 +474,7 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker_multi_tenant
@@ -541,6 +523,8 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
@@ -568,7 +552,7 @@ jobs:
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log

View File

@@ -4,14 +4,7 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read
@@ -44,7 +37,7 @@ jobs:
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage

View File

@@ -48,7 +48,7 @@ jobs:
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
@@ -65,13 +65,7 @@ jobs:
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -125,13 +119,7 @@ jobs:
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -184,13 +172,7 @@ jobs:
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -232,7 +214,7 @@ jobs:
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
cd backend && docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
@@ -301,7 +283,6 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers
@@ -315,6 +296,7 @@ jobs:
api_server \
inference_model_server \
indexing_model_server \
mcp_server \
background \
-d
id: start_docker
@@ -357,6 +339,12 @@ jobs:
}
wait_for_service "http://localhost:8080/health" "API server"
test_dir="${{ matrix.test-dir.path }}"
if [ "$test_dir" = "tests/mcp" ]; then
wait_for_service "http://localhost:8090/health" "MCP server"
else
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
fi
echo "Finished waiting for services."
- name: Start Mock Services
@@ -387,6 +375,8 @@ jobs:
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e MCP_SERVER_HOST=mcp_server \
-e MCP_SERVER_PORT=8090 \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
@@ -424,12 +414,13 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim

View File

@@ -4,14 +4,7 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read
@@ -54,13 +47,7 @@ env:
jobs:
build-web-image:
runs-on:
[
runs-on,
runner=4cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-web-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -115,13 +102,7 @@ jobs:
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -176,13 +157,7 @@ jobs:
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -256,13 +231,14 @@ jobs:
- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache: 'npm'
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -435,7 +411,7 @@ jobs:
fi
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
with:
# Includes test results and trace.zip files
@@ -455,7 +431,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
@@ -471,6 +447,7 @@ jobs:
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.

View File

@@ -144,7 +144,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -16,22 +16,21 @@ jobs:
strategy:
matrix:
os-arch:
- { goos: "linux", goarch: "amd64" }
- { goos: "linux", goarch: "arm64" }
- { goos: "windows", goarch: "amd64" }
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
- { goos: "", goarch: "" }
- {goos: "linux", goarch: "amd64"}
- {goos: "linux", goarch: "arm64"}
- {goos: "windows", goarch: "amd64"}
- {goos: "windows", goarch: "arm64"}
- {goos: "darwin", goarch: "amd64"}
- {goos: "darwin", goarch: "arm64"}
- {goos: "", goarch: ""}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- run: |
GOOS="${{ matrix.os-arch.goos }}" \
GOARCH="${{ matrix.os-arch.goarch }}" \

View File

@@ -21,29 +21,17 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

1
.gitignore vendored
View File

@@ -21,7 +21,6 @@ backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
backend/onyx/evals/one_off/*.json
*.log
*.csv
# secret files
.env

View File

@@ -8,65 +8,30 @@ repos:
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"backend",
"-o",
"backend/requirements/default.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export dev.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"dev",
"-o",
"backend/requirements/dev.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export ee.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"ee",
"-o",
"backend/requirements/ee.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export model_server.txt
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"model_server",
"-o",
"backend/requirements/model_server.txt",
]
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -74,68 +39,69 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-yaml
files: ^.github/
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
- id: actionlint
- repo: https://github.com/psf/black
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
hooks:
- id: black
language_version: python3.11
- id: black
language_version: python3.11
# this is a fork which keeps compatibility with black
- repo: https://github.com/wimglenn/reorder-python-imports-black
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
hooks:
- id: reorder-python-imports
args: ["--py311-plus", "--application-directories=backend/"]
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
- id: reorder-python-imports
args: ['--py311-plus', '--application-directories=backend/']
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
# These settings will remove unused imports with side effects
# Note: The repo currently does not and should not have imports with side effects
- repo: https://github.com/PyCQA/autoflake
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
hooks:
- id: autoflake
args:
[
"--remove-all-unused-imports",
"--remove-unused-variables",
"--in-place",
"--recursive",
]
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
- repo: https://github.com/golangci/golangci-lint
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-prettier
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- repo: https://github.com/sirwart/ripsecrets
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
hooks:
- id: ripsecrets
args:
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- repo: local
hooks:
@@ -146,13 +112,9 @@ repos:
pass_filenames: false
files: \.tf$
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
# This is a preview package - if it breaks:
# 1. Try updating: cd web && npm update @typescript/native-preview
# 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below
- id: typescript-check
name: TypeScript type check
entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json'
entry: bash -c 'cd web && npm run types:check'
language: system
pass_filenames: false
files: ^web/.*\.(ts|tsx)$

View File

@@ -1,45 +1,36 @@
# Copy this file to .env in the .vscode folder.
# Fill in the <REPLACE THIS> values as needed; it is recommended to set the
# GEN_AI_API_KEY value to avoid having to set up an LLM in the UI.
# Also check out onyx/backend/scripts/restart_containers.sh for a script to
# restart the containers which Onyx relies on outside of VSCode/Cursor
# processes.
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed.
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Always keep these on for Dev.
# Logs model prompts, reasoning, and answer to stdout.
# Always keep these on for Dev
# Logs model prompts, reasoning, and answer to stdout
LOG_ONYX_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to
# answer generation.
# This step is quite heavy on token usage so we disable it for dev generally.
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
OPENID_CONFIG_URL=<REPLACE THIS>
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
# Generally not useful for dev, we don't generally want to set up an SMTP server
# for dev.
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI
# every time.
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper.
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
@@ -49,36 +40,26 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1
# Enable the full set of Danswer Enterprise Edition features.
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you
# are using this for local testing/development).
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# S3 File Store Configuration (MinIO for local development)
S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin
# Show extra/uncommon connectors.
# Show extra/uncommon connectors
SHOW_EXTRA_CONNECTORS=True
# Local langsmith tracing
LANGSMITH_TRACING="true"
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
LANGSMITH_API_KEY=<REPLACE_THIS>
LANGSMITH_PROJECT=<REPLACE_THIS>
# Local Confluence OAuth testing
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
# NEXT_PUBLIC_TEST_ENV=True
# OpenSearch
# Arbitrary password is fine for local development.
OPENSEARCH_INITIAL_ADMIN_PASSWORD=<REPLACE THIS>
# NEXT_PUBLIC_TEST_ENV=True

View File

@@ -512,21 +512,6 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// Generic debugger type, required arg but has no bearing on bash.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -1,13 +1,13 @@
# AGENTS.md
This file provides guidance to AI agents when working with code in this repository.
This file provides guidance to Codex when working with code in this repository.
## KEY NOTES
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
`a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
make sure we see logs coming out from the relevant service.
@@ -181,286 +181,6 @@ web/
└── src/lib/ # Utilities & business logic
```
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `web/tailwind-themes/tailwind.config.js` / `web/src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
## Database & Migrations
### Running Migrations
@@ -575,6 +295,14 @@ will be tailing their logs to this file.
- Token management and rate limiting
- Custom prompts and agent actions
## UI/UX Patterns
- Tailwind CSS with design system in `web/src/components/ui/`
- Radix UI and Headless UI for accessible components
- SWR for data fetching and caching
- Form validation with react-hook-form
- Error handling with popup notifications
## Creating a Plan
When creating a plan in the `plans` directory, make sure to include at least these elements:

View File

@@ -7,7 +7,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
`a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
make sure we see logs coming out from the relevant service.
@@ -184,286 +184,6 @@ web/
└── src/lib/ # Utilities & business logic
```
## Frontend Standards
### 1. Import Standards
**Always use absolute imports with the `@` prefix.**
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
```typescript
// ✅ Good
import { Button } from "@/components/ui/button";
import { useAuth } from "@/hooks/useAuth";
import { Text } from "@/refresh-components/texts/Text";
// ❌ Bad
import { Button } from "../../../components/ui/button";
import { useAuth } from "./hooks/useAuth";
```
### 2. React Component Functions
**Prefer regular functions over arrow functions for React components.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
function UserProfile({ userId }: UserProfileProps) {
return <div>User Profile</div>
}
// ❌ Bad
const UserProfile = ({ userId }: UserProfileProps) => {
return <div>User Profile</div>
}
```
### 3. Props Interface Extraction
**Extract prop types into their own interface definitions.**
**Reason:** Functions just become easier to read.
```typescript
// ✅ Good
interface UserCardProps {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
return <div>User Card</div>
}
// ❌ Bad
function UserCard({
user,
showActions = false,
onEdit
}: {
user: User
showActions?: boolean
onEdit?: (userId: string) => void
}) {
return <div>User Card</div>
}
```
### 4. Spacing Guidelines
**Prefer padding over margins for spacing.**
**Reason:** We want to consolidate usage to paddings instead of margins.
```typescript
// ✅ Good
<div className="p-4 space-y-2">
<div className="p-2">Content</div>
</div>
// ❌ Bad
<div className="m-4 space-y-2">
<div className="m-2">Content</div>
</div>
```
### 5. Tailwind Dark Mode
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
```typescript
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
<div className="bg-background-neutral-03 text-text-02">
Content
</div>
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
export const GithubIcon = createLogoIcon(githubLightIcon, {
monochromatic: true, // Will apply dark:invert internally
});
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
});
// ❌ Bad - Manual dark mode overrides
<div className="bg-white dark:bg-black text-black dark:text-white">
Content
</div>
```
### 6. Class Name Utilities
**Use the `cn` utility instead of raw string formatting for classNames.**
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
```typescript
import { cn } from '@/lib/utils'
// ✅ Good
<div className={cn(
'base-class',
isActive && 'active-class',
className
)}>
Content
</div>
// ❌ Bad
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
Content
</div>
```
### 7. Custom Hooks Organization
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
**Reason:** This is just a layout preference. Keeps code clean.
```typescript
// web/src/hooks/useUserData.ts
export function useUserData(userId: string) {
// hook implementation
}
// web/src/hooks/useLocalStorage.ts
export function useLocalStorage<T>(key: string, initialValue: T) {
// hook implementation
}
```
### 8. Icon Usage
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
```typescript
// ✅ Good
import SvgX from "@/icons/x";
import SvgMoreHorizontal from "@/icons/more-horizontal";
// ❌ Bad
import { User } from "lucide-react";
import { FiSearch } from "react-icons/fi";
```
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
If you need help with this step, reach out to `raunak@onyx.app`.
### 9. Text Rendering
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
```typescript
// ✅ Good
import { Text } from '@/refresh-components/texts/Text'
function UserCard({ name }: { name: string }) {
return (
<Text
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
text03
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
mainAction
>
{name}
</Text>
)
}
// ❌ Bad
function UserCard({ name }: { name: string }) {
return (
<div>
<h2>{name}</h2>
<p>User details</p>
</div>
)
}
```
### 10. Component Usage
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
```typescript
// ✅ Good
import Button from '@/refresh-components/buttons/Button'
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
import SvgPlusCircle from '@/icons/plus-circle'
function ContactForm() {
return (
<form>
<InputTypeIn placeholder="Search..." />
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
</form>
)
}
// ❌ Bad
function ContactForm() {
return (
<form>
<input placeholder="Name" />
<textarea placeholder="Message" />
<button type="submit">Submit</button>
</form>
)
}
```
### 11. Colors
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
**Available color categories:**
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
- **Actions:** `action-link-XX`, `action-danger-XX`
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
```typescript
// ✅ Good - Use custom Onyx color classes
<div className="bg-background-neutral-01 border border-border-02" />
<div className="bg-background-tint-02 border border-border-01" />
<div className="bg-status-success-01" />
<div className="bg-action-link-01" />
<div className="bg-theme-primary-05" />
// ❌ Bad - Do NOT use standard Tailwind colors
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
<div className="bg-white border border-slate-200" />
<div className="bg-green-100 text-green-700" />
<div className="bg-blue-100 text-blue-600" />
<div className="bg-indigo-500" />
```
### 12. Data Fetching
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
## Database & Migrations
### Running Migrations
@@ -580,6 +300,14 @@ will be tailing their logs to this file.
- Token management and rate limiting
- Custom prompts and agent actions
## UI/UX Patterns
- Tailwind CSS with design system in `web/src/components/ui/`
- Radix UI and Headless UI for accessible components
- SWR for data fetching and caching
- Form validation with react-hook-form
- Error handling with popup notifications
## Creating a Plan
When creating a plan in the `plans` directory, make sure to include at least these elements:

View File

@@ -161,7 +161,7 @@ You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
```bash
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
docker compose up -d index relational_db cache minio
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)

View File

@@ -15,4 +15,3 @@ build/
dist/
.coverage
htmlcov/
model_server/legacy/

View File

@@ -13,10 +13,23 @@ RUN uv pip install --system --no-cache-dir --upgrade \
-r /tmp/requirements.txt && \
rm -rf ~/.cache/uv /tmp/*.txt
# Stage for downloading embedding models
# Stage for downloading tokenizers
FROM base AS tokenizers
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1');"
# Stage for downloading Onyx models
FROM base AS onyx-models
RUN python -c "from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model');"
# Stage for downloading embedding and reranking models
FROM base AS embedding-models
RUN python -c "from huggingface_hub import snapshot_download; \
snapshot_download('nomic-ai/nomic-embed-text-v1');"
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"
# Initialize SentenceTransformer to cache the custom architecture
RUN python -c "from sentence_transformers import SentenceTransformer; \
@@ -41,6 +54,8 @@ RUN groupadd -g 1001 onyx && \
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
# it's preserved in order to combine with the user's cache contents
COPY --chown=onyx:onyx --from=tokenizers /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=onyx-models /app/.cache/huggingface /app/.cache/temp_huggingface
COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface
WORKDIR /app

View File

@@ -39,9 +39,7 @@ config = context.config
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
fileConfig(config.config_file_name, disable_existing_loggers=False)
fileConfig(config.config_file_name)
target_metadata = [Base.metadata, ResultModelBase.metadata]
@@ -225,6 +223,7 @@ def do_run_migrations(
) -> None:
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema_name}"'))
@@ -308,7 +307,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -346,7 +344,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -463,49 +460,8 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
Supports pytest-alembic by checking for a pre-configured connection
in context.config.attributes["connection"]. If present, uses that
connection/engine directly instead of creating a new async engine.
"""
# Check if pytest-alembic is providing a connection/engine
connectable = context.config.attributes.get("connection", None)
if connectable is not None:
# pytest-alembic is providing an engine - use it directly
logger.info("run_migrations_online starting (pytest-alembic mode).")
# For pytest-alembic, we use the default schema (public)
schema_name = context.config.attributes.get(
"schema_name", POSTGRES_DEFAULT_SCHEMA
)
# pytest-alembic passes an Engine, we need to get a connection from it
with connectable.connect() as connection:
# Set search path for the schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
compare_server_default=True,
script_location=config.get_main_option("script_location"),
)
with context.begin_transaction():
context.run_migrations()
# Commit the transaction to ensure changes are visible to next migration
connection.commit()
else:
# Normal operation - use async migrations
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())
if context.is_offline_mode():

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None
depends_on = None
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:

View File

@@ -1,27 +0,0 @@
"""add last refreshed at mcp server
Revision ID: 2a391f840e85
Revises: 4cebcbc9b2ae
Create Date: 2025-12-06 15:19:59.766066
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembi.
revision = "2a391f840e85"
down_revision = "4cebcbc9b2ae"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"mcp_server",
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("mcp_server", "last_refreshed_at")

View File

@@ -1,46 +0,0 @@
"""usage_limits
Revision ID: 2b90f3af54b8
Revises: 9a0296d7421e
Create Date: 2026-01-03 16:55:30.449692
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2b90f3af54b8"
down_revision = "9a0296d7421e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"tenant_usage",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"window_start", sa.DateTime(timezone=True), nullable=False, index=True
),
sa.Column("llm_cost_cents", sa.Float(), nullable=False, server_default="0.0"),
sa.Column("chunks_indexed", sa.Integer(), nullable=False, server_default="0"),
sa.Column("api_calls", sa.Integer(), nullable=False, server_default="0"),
sa.Column(
"non_streaming_api_calls", sa.Integer(), nullable=False, server_default="0"
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=True,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("window_start", name="uq_tenant_usage_window"),
)
def downgrade() -> None:
op.drop_index("ix_tenant_usage_window_start", table_name="tenant_usage")
op.drop_table("tenant_usage")

View File

@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.llm.llm_provider_options import (
fetch_model_names_for_provider_as_set,
fetch_visible_model_names_for_provider_as_set,
)

View File

@@ -1,27 +0,0 @@
"""add tab_index to tool_call
Revision ID: 4cebcbc9b2ae
Revises: a1b2c3d4e5f6
Create Date: 2025-12-16
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4cebcbc9b2ae"
down_revision = "a1b2c3d4e5f6"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"tool_call",
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
)
def downgrade() -> None:
op.drop_column("tool_call", "tab_index")

View File

@@ -62,11 +62,6 @@ def upgrade() -> None:
)
"""
)
# Drop the temporary table to avoid conflicts if migration runs again
# (e.g., during upgrade -> downgrade -> upgrade cycles in tests)
op.execute("DROP TABLE IF EXISTS temp_connector_credential")
# If no exception was raised, alter the column
op.alter_column("credential", "source", nullable=True) # TODO modify
# # ### end Alembic commands ###

View File

@@ -85,122 +85,103 @@ class UserRow(NamedTuple):
def upgrade() -> None:
conn = op.get_bind()
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
# Start transaction
conn.execute(sa.text("BEGIN"))
if search_assistant:
# Update existing Search assistant to be the unified assistant
try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)
if web_search_tool:
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
@@ -209,148 +190,191 @@ def upgrade() -> None:
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
{"tool_id": search_tool[0]},
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
),
{"tool_id": image_gen_tool[0]},
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
if web_search_tool:
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:
conn = op.get_bind()
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
)
)
)
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e

View File

@@ -1,35 +0,0 @@
"""backend driven notification details
Revision ID: 5c3dca366b35
Revises: 9087b548dd69
Create Date: 2026-01-06 16:03:11.413724
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c3dca366b35"
down_revision = "9087b548dd69"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification",
sa.Column(
"title", sa.String(), nullable=False, server_default="New Notification"
),
)
op.add_column(
"notification",
sa.Column("description", sa.String(), nullable=True, server_default=""),
)
def downgrade() -> None:
op.drop_column("notification", "title")
op.drop_column("notification", "description")

View File

@@ -1,75 +0,0 @@
"""nullify_default_task_prompt
Revision ID: 699221885109
Revises: 7e490836d179
Create Date: 2025-12-30 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "699221885109"
down_revision = "7e490836d179"
branch_labels = None
depends_on = None
DEFAULT_PERSONA_ID = 0
def upgrade() -> None:
# Make task_prompt column nullable
# Note: The model had nullable=True but the DB column was NOT NULL until this point
op.alter_column(
"persona",
"task_prompt",
nullable=True,
)
# Set task_prompt to NULL for the default persona
conn = op.get_bind()
conn.execute(
sa.text(
"""
UPDATE persona
SET task_prompt = NULL
WHERE id = :persona_id
"""
),
{"persona_id": DEFAULT_PERSONA_ID},
)
def downgrade() -> None:
# Restore task_prompt to empty string for the default persona
conn = op.get_bind()
conn.execute(
sa.text(
"""
UPDATE persona
SET task_prompt = ''
WHERE id = :persona_id AND task_prompt IS NULL
"""
),
{"persona_id": DEFAULT_PERSONA_ID},
)
# Set any remaining NULL task_prompts to empty string before making non-nullable
conn.execute(
sa.text(
"""
UPDATE persona
SET task_prompt = ''
WHERE task_prompt IS NULL
"""
)
)
# Revert task_prompt column to not nullable
op.alter_column(
"persona",
"task_prompt",
nullable=False,
)

View File

@@ -1,54 +0,0 @@
"""add image generation config table
Revision ID: 7206234e012a
Revises: 699221885109
Create Date: 2025-12-21 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7206234e012a"
down_revision = "699221885109"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"image_generation_config",
sa.Column("image_provider_id", sa.String(), primary_key=True),
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["model_configuration_id"],
["model_configuration.id"],
ondelete="CASCADE",
),
)
op.create_index(
"ix_image_generation_config_is_default",
"image_generation_config",
["is_default"],
unique=False,
)
op.create_index(
"ix_image_generation_config_model_configuration_id",
"image_generation_config",
["model_configuration_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_image_generation_config_model_configuration_id",
table_name="image_generation_config",
)
op.drop_index(
"ix_image_generation_config_is_default", table_name="image_generation_config"
)
op.drop_table("image_generation_config")

View File

@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.llm.llm_provider_options import (
fetch_model_names_for_provider_as_set,
fetch_visible_model_names_for_provider_as_set,
)

View File

@@ -1,80 +0,0 @@
"""nullify_default_system_prompt
Revision ID: 7e490836d179
Revises: c1d2e3f4a5b6
Create Date: 2025-12-29 16:54:36.635574
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7e490836d179"
down_revision = "c1d2e3f4a5b6"
branch_labels = None
depends_on = None
# This is the default system prompt from the previous migration (87c52ec39f84)
# ruff: noqa: E501, W605 start
PREVIOUS_DEFAULT_SYSTEM_PROMPT = """
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
# Response Style
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
For code you prefer to use Markdown and specify the language.
You can use horizontal rules (---) to separate sections of your responses.
You can use Markdown tables to format your responses for data, lists, and other structured information.
""".lstrip()
# ruff: noqa: E501, W605 end
def upgrade() -> None:
# Make system_prompt column nullable (model already has nullable=True but DB doesn't)
op.alter_column(
"persona",
"system_prompt",
nullable=True,
)
# Set system_prompt to NULL where it matches the previous default
conn = op.get_bind()
conn.execute(
sa.text(
"""
UPDATE persona
SET system_prompt = NULL
WHERE system_prompt = :previous_default
"""
),
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
)
def downgrade() -> None:
# Restore the default system prompt for personas that have NULL
# Note: This may restore the prompt to personas that originally had NULL
# before this migration, but there's no way to distinguish them
conn = op.get_bind()
conn.execute(
sa.text(
"""
UPDATE persona
SET system_prompt = :previous_default
WHERE system_prompt IS NULL
"""
),
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
)
# Revert system_prompt column to not nullable
op.alter_column(
"persona",
"system_prompt",
nullable=False,
)

View File

@@ -42,13 +42,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
)
@@ -63,13 +63,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_default=sa.text("now()"), # type: ignore
nullable=False,
),
sa.ForeignKeyConstraint(

View File

@@ -1,49 +0,0 @@
"""notifications constraint, sort index, and cleanup old notifications
Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique index for notification deduplication.
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
#
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
# in unique constraints, but we want NULL == NULL for deduplication).
# The '{}' represents an empty JSONB object as the NULL replacement.
# Clean up legacy notifications first
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
"""
)
# Create index for efficient notification sorting by user
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
ON notification (user_id, dismissed, first_shown DESC)
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")

View File

@@ -1,136 +0,0 @@
"""seed_default_image_gen_config
Revision ID: 9087b548dd69
Revises: 2b90f3af54b8
Create Date: 2026-01-05 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9087b548dd69"
down_revision = "2b90f3af54b8"
branch_labels = None
depends_on = None
# Constants for default image generation config
# Source: web/src/app/admin/configuration/image-generation/constants.ts
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
MODEL_NAME = "gpt-image-1"
PROVIDER_NAME = "openai"
def upgrade() -> None:
conn = op.get_bind()
# Check if image_generation_config table already has records
existing_configs = (
conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
or 0
)
if existing_configs > 0:
# Skip if configs already exist - user may have configured manually
return
# Find the first OpenAI LLM provider
openai_provider = conn.execute(
sa.text(
"""
SELECT id, api_key
FROM llm_provider
WHERE provider = :provider
ORDER BY id
LIMIT 1
"""
),
{"provider": PROVIDER_NAME},
).fetchone()
if not openai_provider:
# No OpenAI provider found - nothing to do
return
source_provider_id, api_key = openai_provider
# Create new LLM provider for image generation (clone only api_key)
result = conn.execute(
sa.text(
"""
INSERT INTO llm_provider (
name, provider, api_key, api_base, api_version,
deployment_name, default_model_name, is_public,
is_default_provider, is_default_vision_provider, is_auto_mode
)
VALUES (
:name, :provider, :api_key, NULL, NULL,
NULL, :default_model_name, :is_public,
NULL, NULL, :is_auto_mode
)
RETURNING id
"""
),
{
"name": f"Image Gen - {IMAGE_PROVIDER_ID}",
"provider": PROVIDER_NAME,
"api_key": api_key,
"default_model_name": MODEL_NAME,
"is_public": True,
"is_auto_mode": False,
},
)
new_provider_id = result.scalar()
# Create model configuration
result = conn.execute(
sa.text(
"""
INSERT INTO model_configuration (
llm_provider_id, name, is_visible, max_input_tokens,
supports_image_input, display_name
)
VALUES (
:llm_provider_id, :name, :is_visible, :max_input_tokens,
:supports_image_input, :display_name
)
RETURNING id
"""
),
{
"llm_provider_id": new_provider_id,
"name": MODEL_NAME,
"is_visible": True,
"max_input_tokens": None,
"supports_image_input": False,
"display_name": None,
},
)
model_config_id = result.scalar()
# Create image generation config
conn.execute(
sa.text(
"""
INSERT INTO image_generation_config (
image_provider_id, model_configuration_id, is_default
)
VALUES (
:image_provider_id, :model_configuration_id, :is_default
)
"""
),
{
"image_provider_id": IMAGE_PROVIDER_ID,
"model_configuration_id": model_config_id,
"is_default": True,
},
)
def downgrade() -> None:
# We don't remove the config on downgrade since it's safe to keep around
# If we upgrade again, it will be a no-op due to the existing records check
pass

View File

@@ -1,33 +0,0 @@
"""add_is_auto_mode_to_llm_provider
Revision ID: 9a0296d7421e
Revises: 7206234e012a
Create Date: 2025-12-17 18:14:29.620981
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9a0296d7421e"
down_revision = "7206234e012a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_auto_mode",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
def downgrade() -> None:
op.drop_column("llm_provider", "is_auto_mode")

View File

@@ -234,8 +234,6 @@ def downgrade() -> None:
if "instructions" in columns:
op.drop_column("user_project", "instructions")
op.execute("ALTER TABLE user_project RENAME TO user_folder")
# Update NULL descriptions to empty string before setting NOT NULL constraint
op.execute("UPDATE user_folder SET description = '' WHERE description IS NULL")
op.alter_column("user_folder", "description", nullable=False)
logger.info("Renamed user_project back to user_folder")

View File

@@ -42,13 +42,20 @@ TOOL_DESCRIPTIONS = {
def upgrade() -> None:
conn = op.get_bind()
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("BEGIN"))
try:
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("COMMIT"))
except Exception as e:
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:

View File

@@ -1,49 +0,0 @@
"""add license table
Revision ID: a1b2c3d4e5f6
Revises: a01bf2971c5d
Create Date: 2025-12-04 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "a01bf2971c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"license",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("license_data", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
# Singleton pattern - only ever one row in this table
op.create_index(
"idx_license_singleton",
"license",
[sa.text("(true)")],
unique=True,
)
def downgrade() -> None:
op.drop_index("idx_license_singleton", table_name="license")
op.drop_table("license")

View File

@@ -1,27 +0,0 @@
"""Remove fast_default_model_name from llm_provider
Revision ID: a2b3c4d5e6f7
Revises: 2a391f840e85
Create Date: 2024-12-17
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a2b3c4d5e6f7"
down_revision = "2a391f840e85"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("llm_provider", "fast_default_model_name")
def downgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("fast_default_model_name", sa.String(), nullable=True),
)

View File

@@ -1,39 +0,0 @@
"""remove userfile related deprecated fields
Revision ID: a3c1a7904cd0
Revises: 5c3dca366b35
Create Date: 2026-01-06 13:00:30.634396
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a3c1a7904cd0"
down_revision = "5c3dca366b35"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user_file", "document_id")
op.drop_column("user_file", "document_id_migrated")
op.drop_column("connector_credential_pair", "is_user_file")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"user_file",
sa.Column("document_id", sa.String(), nullable=True),
)
op.add_column(
"user_file",
sa.Column(
"document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
),
)

View File

@@ -280,14 +280,6 @@ def downgrade() -> None:
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
# Recreate the FK constraint that was implicitly dropped when the column was dropped
op.create_foreign_key(
"fk_chat_message_persona",
"chat_message",
"persona",
["alternate_assistant_id"],
["id"],
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)

View File

@@ -1,46 +0,0 @@
"""Drop milestone table
Revision ID: b8c9d0e1f2a3
Revises: a2b3c4d5e6f7
Create Date: 2025-12-18
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b8c9d0e1f2a3"
down_revision = "a2b3c4d5e6f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("milestone")
def downgrade() -> None:
op.create_table(
"milestone",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=True),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("event_type", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)

View File

@@ -1,51 +0,0 @@
"""add_deep_research_tool
Revision ID: c1d2e3f4a5b6
Revises: b8c9d0e1f2a3
Create Date: 2025-12-18 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c1d2e3f4a5b6"
down_revision = "b8c9d0e1f2a3"
branch_labels = None
depends_on = None
DEEP_RESEARCH_TOOL = {
"name": "ResearchAgent",
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",
}
def upgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
"""
),
DEEP_RESEARCH_TOOL,
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
DELETE FROM tool
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
)

View File

@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
print(f"File {file_id} not found in PostgreSQL storage.")
continue
lobj_id = cast(int, file_record.lobj_oid)
file_metadata = cast(Any, file_record.file_metadata)
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
# Read file content from PostgreSQL
try:
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
else:
# Convert other types to dict if possible, otherwise None
try:
file_metadata = dict(file_record.file_metadata)
file_metadata = dict(file_record.file_metadata) # type: ignore
except (TypeError, ValueError):
file_metadata = None

View File

@@ -70,66 +70,80 @@ BUILT_IN_TOOLS = [
def upgrade() -> None:
conn = op.get_bind()
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Start transaction
conn.execute(sa.text("BEGIN"))
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
try:
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
)
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:

View File

@@ -1,64 +0,0 @@
"""sync_exa_api_key_to_content_provider
Revision ID: d1b637d7050a
Revises: d25168c2beee
Create Date: 2026-01-09 15:54:15.646249
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "d1b637d7050a"
down_revision = "d25168c2beee"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Exa uses a shared API key between search and content providers.
# For existing Exa search providers with API keys, create the corresponding
# content provider if it doesn't exist yet.
connection = op.get_bind()
# Check if Exa search provider exists with an API key
result = connection.execute(
text(
"""
SELECT api_key FROM internet_search_provider
WHERE provider_type = 'exa' AND api_key IS NOT NULL
LIMIT 1
"""
)
)
row = result.fetchone()
if row:
api_key = row[0]
# Create Exa content provider with the shared key
connection.execute(
text(
"""
INSERT INTO internet_content_provider
(name, provider_type, api_key, is_active)
VALUES ('Exa', 'exa', :api_key, false)
ON CONFLICT (name) DO NOTHING
"""
),
{"api_key": api_key},
)
def downgrade() -> None:
# Remove the Exa content provider that was created by this migration
connection = op.get_bind()
connection.execute(
text(
"""
DELETE FROM internet_content_provider
WHERE provider_type = 'exa'
"""
)
)

View File

@@ -1,86 +0,0 @@
"""tool_name_consistency
Revision ID: d25168c2beee
Revises: 8405ca81cc83
Create Date: 2026-01-11 17:54:40.135777
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d25168c2beee"
down_revision = "8405ca81cc83"
branch_labels = None
depends_on = None
# Currently the seeded tools have the in_code_tool_id == name
CURRENT_TOOL_NAME_MAPPING = [
"SearchTool",
"WebSearchTool",
"ImageGenerationTool",
"PythonTool",
"OpenURLTool",
"KnowledgeGraphTool",
"ResearchAgent",
]
# Mapping of in_code_tool_id -> name
# These are the expected names that we want in the database
EXPECTED_TOOL_NAME_MAPPING = {
"SearchTool": "internal_search",
"WebSearchTool": "web_search",
"ImageGenerationTool": "generate_image",
"PythonTool": "python",
"OpenURLTool": "open_url",
"KnowledgeGraphTool": "run_kg_search",
"ResearchAgent": "research_agent",
}
def upgrade() -> None:
conn = op.get_bind()
# Mapping of in_code_tool_id to the NAME constant from each tool class
# These match the .name property of each tool implementation
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
# Update the name column for each tool based on its in_code_tool_id
for in_code_tool_id, expected_name in tool_name_mapping.items():
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :expected_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"expected_name": expected_name,
"in_code_tool_id": in_code_tool_id,
},
)
def downgrade() -> None:
conn = op.get_bind()
# Reverse the migration by setting name back to in_code_tool_id
# This matches the original pattern where name was the class name
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :current_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"current_name": in_code_tool_id,
"in_code_tool_id": in_code_tool_id,
},
)

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None
depends_on = None
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:

View File

@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import (
from onyx.db.enums import ( # type: ignore[import-untyped]
MCPTransport,
MCPAuthenticationType,
MCPAuthenticationPerformer,

View File

@@ -20,9 +20,7 @@ config = context.config
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
fileConfig(config.config_file_name, disable_existing_loggers=False)
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
@@ -84,9 +82,9 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore[arg-type]
target_metadata=target_metadata, # type: ignore
include_object=include_object,
)
) # type: ignore
with context.begin_transaction():
context.run_migrations()
@@ -110,24 +108,9 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
"""Run migrations in 'online' mode."""
Supports pytest-alembic by checking for a pre-configured connection
in context.config.attributes["connection"]. If present, uses that
connection/engine directly instead of creating a new async engine.
"""
# Check if pytest-alembic is providing a connection/engine
connectable = context.config.attributes.get("connection", None)
if connectable is not None:
# pytest-alembic is providing an engine - use it directly
with connectable.connect() as connection:
do_run_migrations(connection)
# Commit to ensure changes are visible to next migration
connection.commit()
else:
# Normal operation - use async migrations
asyncio.run(run_async_migrations())
asyncio.run(run_async_migrations())
if context.is_offline_mode():

View File

@@ -1,15 +1,11 @@
group "default" {
targets = ["backend", "model-server", "web"]
targets = ["backend", "model-server"]
}
variable "BACKEND_REPOSITORY" {
default = "onyxdotapp/onyx-backend"
}
variable "WEB_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-web-server"
}
variable "MODEL_SERVER_REPOSITORY" {
default = "onyxdotapp/onyx-model-server"
}
@@ -23,7 +19,7 @@ variable "TAG" {
}
target "backend" {
context = "backend"
context = "."
dockerfile = "Dockerfile"
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
@@ -32,18 +28,8 @@ target "backend" {
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
}
target "web" {
context = "web"
dockerfile = "Dockerfile"
cache-from = ["type=registry,ref=${WEB_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${WEB_SERVER_REPOSITORY}:${TAG}"]
}
target "model-server" {
context = "backend"
context = "."
dockerfile = "Dockerfile.model_server"
@@ -54,7 +40,7 @@ target "model-server" {
}
target "integration" {
context = "backend"
context = "."
dockerfile = "tests/integration/Dockerfile"
// Provide the base image via build context from the backend target

View File

@@ -109,6 +109,11 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)

View File

@@ -118,6 +118,6 @@ def fetch_document_sets(
.all()
)
document_set_with_cc_pairs.append((document_set, cc_pairs))
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
return document_set_with_cc_pairs

View File

@@ -1,278 +0,0 @@
"""Database and cache operations for the license table."""
from datetime import datetime
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
def get_license(db_session: Session) -> License | None:
"""
Get the current license (singleton pattern - only one row).
Args:
db_session: Database session
Returns:
License object if exists, None otherwise
"""
return db_session.execute(select(License)).scalars().first()
def upsert_license(db_session: Session, license_data: str) -> License:
"""
Insert or update the license (singleton pattern).
Args:
db_session: Database session
license_data: Base64-encoded signed license blob
Returns:
The created or updated License object
"""
existing = get_license(db_session)
if existing:
existing.license_data = license_data
db_session.commit()
db_session.refresh(existing)
logger.info("License updated")
return existing
new_license = License(license_data=license_data)
db_session.add(new_license)
db_session.commit()
db_session.refresh(new_license)
logger.info("License created")
return new_license
def delete_license(db_session: Session) -> bool:
"""
Delete the current license.
Args:
db_session: Database session
Returns:
True if deleted, False if no license existed
"""
existing = get_license(db_session)
if existing:
db_session.delete(existing)
db_session.commit()
logger.info("License deleted")
return True
return False
# -----------------------------------------------------------------------------
# Seat Counting
# -----------------------------------------------------------------------------
def get_used_seats(tenant_id: str | None = None) -> int:
"""
Get current seat usage.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (includes both Onyx UI users
and Slack users who have been converted to Onyx users).
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
return get_tenant_count(tenant_id or get_current_tenant_id())
else:
# Self-hosted: count all active users (Onyx + converted Slack users)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
result = db_session.execute(
select(func.count()).select_from(User).where(User.is_active) # type: ignore
)
return result.scalar() or 0
# -----------------------------------------------------------------------------
# Redis Cache Operations
# -----------------------------------------------------------------------------
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
"""
Get license metadata from Redis cache.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if cached, None otherwise
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_replica_client(tenant_id=tenant)
cached = redis_client.get(LICENSE_METADATA_KEY)
if cached:
try:
cached_str: str
if isinstance(cached, bytes):
cached_str = cached.decode("utf-8")
else:
cached_str = str(cached)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
return None
def invalidate_license_cache(tenant_id: str | None = None) -> None:
"""
Invalidate the license metadata cache (not the license itself).
This deletes the cached LicenseMetadata from Redis. The actual license
in the database is not affected. Redis delete is idempotent - if the
key doesn't exist, this is a no-op.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
redis_client.delete(LICENSE_METADATA_KEY)
logger.info("License cache invalidated")
def update_license_cache(
payload: LicensePayload,
source: LicenseSource | None = None,
grace_period_end: datetime | None = None,
tenant_id: str | None = None,
) -> LicenseMetadata:
"""
Update the Redis cache with license metadata.
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
1. Frontend needs status to show appropriate UI/banners
2. Caching avoids repeated DB + crypto verification on every request
3. Status enforcement happens at the feature level, not here
Args:
payload: Verified license payload
source: How the license was obtained
grace_period_end: Optional grace period end time
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
The cached LicenseMetadata
"""
from ee.onyx.utils.license import get_license_status
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
used_seats = get_used_seats(tenant)
status = get_license_status(payload, grace_period_end)
metadata = LicenseMetadata(
tenant_id=payload.tenant_id,
organization_name=payload.organization_name,
seats=payload.seats,
used_seats=used_seats,
plan_type=payload.plan_type,
issued_at=payload.issued_at,
expires_at=payload.expires_at,
grace_period_end=grace_period_end,
status=status,
source=source,
stripe_subscription_id=payload.stripe_subscription_id,
)
redis_client.setex(
LICENSE_METADATA_KEY,
LICENSE_CACHE_TTL_SECONDS,
metadata.model_dump_json(),
)
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
return metadata
def refresh_license_cache(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Refresh the license cache from the database.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
from ee.onyx.utils.license import verify_license_signature
license_record = get_license(db_session)
if not license_record:
invalidate_license_cache(tenant_id)
return None
try:
payload = verify_license_signature(license_record.license_data)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:
logger.error(f"Failed to verify license during cache refresh: {e}")
invalidate_license_cache(tenant_id)
return None
def get_license_metadata(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Get license metadata, using cache if available.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
# Try cache first
cached = get_cached_license_metadata(tenant_id)
if cached:
return cached
# Refresh from database
return refresh_license_cache(db_session, tenant_id)

View File

@@ -3,42 +3,30 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
def update_persona_access(
def make_persona_private(
persona_id: int,
creator_user_id: UUID | None,
user_ids: list[UUID] | None,
group_ids: list[int] | None,
db_session: Session,
is_public: bool | None = None,
user_ids: list[UUID] | None = None,
group_ids: list[int] | None = None,
) -> None:
"""Updates the access settings for a persona including public status, user shares,
and group shares.
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
NOTE: This function batches all updates. If we don't dedupe the inputs,
the commit will exception.
NOTE: Callers are responsible for committing."""
if is_public is not None:
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
# and a non-empty list means "replace with these shares".
if user_ids is not None:
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
@@ -46,20 +34,17 @@ def update_persona_access(
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
title="A new agent was shared with you!",
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
if group_ids is not None:
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
if group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)
db_session.commit()

View File

@@ -14,7 +14,6 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
@@ -140,8 +139,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
include_router_with_global_prefix_prepended(application, usage_export_router)
# License management
include_router_with_global_prefix_prepended(application, license_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -21,9 +21,8 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.utils import PUBLIC_API_TAGS
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/analytics")
_DEFAULT_LOOKBACK_DAYS = 30

View File

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any
from typing import List
@@ -24,12 +23,6 @@ class NavigationItem(BaseModel):
return instance
class LogoDisplayStyle(str, Enum):
LOGO_AND_NAME = "logo_and_name"
LOGO_ONLY = "logo_only"
NAME_ONLY = "name_only"
class EnterpriseSettings(BaseModel):
"""General settings that only apply to the Enterprise Edition of Onyx
@@ -38,7 +31,6 @@ class EnterpriseSettings(BaseModel):
application_name: str | None = None
use_custom_logo: bool = False
use_custom_logotype: bool = False
logo_display_style: LogoDisplayStyle | None = None
# custom navigation
custom_nav_items: List[NavigationItem] = Field(default_factory=list)
@@ -50,9 +42,6 @@ class EnterpriseSettings(BaseModel):
custom_popup_header: str | None = None
custom_popup_content: str | None = None
enable_consent_screen: bool | None = None
consent_screen_prompt: str | None = None
show_first_visit_notice: bool | None = None
custom_greeting_message: str | None = None
def check_validity(self) -> None:
return

View File

@@ -1,246 +0,0 @@
"""License API endpoints."""
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
from ee.onyx.db.license import refresh_license_cache
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseResponse
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/license")
@router.get("")
async def get_license_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""Get current license status and seat usage."""
metadata = get_license_metadata(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.get("/seats")
async def get_seat_usage(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
"""Get detailed seat usage information."""
metadata = get_license_metadata(db_session)
if not metadata:
return SeatUsageResponse(
total_seats=0,
used_seats=0,
available_seats=0,
)
return SeatUsageResponse(
total_seats=metadata.seats,
used_seats=metadata.used_seats,
available_seats=max(0, metadata.seats - metadata.used_seats),
)
@router.post("/fetch")
async def fetch_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
Fetch license from control plane.
Used after Stripe checkout completion to retrieve the new license.
"""
tenant_id = get_current_tenant_id()
try:
token = generate_data_plane_token()
except ValueError as e:
logger.error(f"Failed to generate data plane token: {e}")
raise HTTPException(
status_code=500, detail="Authentication configuration error"
)
try:
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
if not isinstance(data, dict) or "license" not in data:
raise HTTPException(
status_code=502, detail="Invalid response from control plane"
)
license_data = data["license"]
if not license_data:
raise HTTPException(status_code=404, detail="No license found")
# Verify signature before persisting
payload = verify_license_signature(license_data)
# Verify the fetched license is for this tenant
if payload.tenant_id != tenant_id:
logger.error(
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
)
raise HTTPException(
status_code=400,
detail="License tenant ID mismatch - control plane returned wrong license",
)
# Persist to DB and update cache atomically
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseResponse(success=True, license=payload)
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else 502
logger.error(f"Control plane returned error: {status_code}")
raise HTTPException(
status_code=status_code,
detail="Failed to fetch license from control plane",
)
except ValueError as e:
logger.error(f"License verification failed: {type(e).__name__}")
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
logger.exception("Failed to fetch license from control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
@router.post("/upload")
async def upload_license(
license_file: UploadFile = File(...),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
Upload a license file manually.
Used for air-gapped deployments where control plane is not accessible.
"""
try:
content = await license_file.read()
license_data = content.decode("utf-8").strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tenant_id = get_current_tenant_id()
if payload.tenant_id != tenant_id:
raise HTTPException(
status_code=400,
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseUploadResponse(
success=True,
message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
)
@router.post("/refresh")
async def refresh_license_cache_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
Force refresh the license cache from the database.
Useful after manual database changes or to verify license validity.
"""
metadata = refresh_license_cache(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.delete("")
async def delete_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, bool]:
"""
Delete the current license.
Admin only - removes license and invalidates cache.
"""
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
try:
invalidate_license_cache()
except Exception as cache_error:
logger.warning(f"Failed to invalidate license cache: {cache_error}")
deleted = db_delete_license(db_session)
return {"deleted": deleted}

View File

@@ -1,92 +0,0 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from onyx.server.settings.models import ApplicationStatus
class PlanType(str, Enum):
MONTHLY = "monthly"
ANNUAL = "annual"
class LicenseSource(str, Enum):
AUTO_FETCH = "auto_fetch"
MANUAL_UPLOAD = "manual_upload"
class LicensePayload(BaseModel):
"""The payload portion of a signed license."""
version: str
tenant_id: str
organization_name: str | None = None
issued_at: datetime
expires_at: datetime
seats: int
plan_type: PlanType
billing_cycle: str | None = None
grace_period_days: int = 30
stripe_subscription_id: str | None = None
stripe_customer_id: str | None = None
class LicenseData(BaseModel):
"""Full signed license structure."""
payload: LicensePayload
signature: str
class LicenseMetadata(BaseModel):
"""Cached license metadata stored in Redis."""
tenant_id: str
organization_name: str | None = None
seats: int
used_seats: int
plan_type: PlanType
issued_at: datetime
expires_at: datetime
grace_period_end: datetime | None = None
status: ApplicationStatus
source: LicenseSource | None = None
stripe_subscription_id: str | None = None
class LicenseStatusResponse(BaseModel):
"""Response for license status API."""
has_license: bool
seats: int = 0
used_seats: int = 0
plan_type: PlanType | None = None
issued_at: datetime | None = None
expires_at: datetime | None = None
grace_period_end: datetime | None = None
status: ApplicationStatus | None = None
source: LicenseSource | None = None
class LicenseResponse(BaseModel):
"""Response after license fetch/upload."""
success: bool
message: str | None = None
license: LicensePayload | None = None
class LicenseUploadResponse(BaseModel):
"""Response after license upload."""
success: bool
message: str | None = None
class SeatUsageResponse(BaseModel):
"""Response for seat usage API."""
total_seats: int
used_seats: int
available_seats: int

View File

@@ -20,10 +20,9 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -101,13 +100,14 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
origin=MessageOrigin.API,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
return gather_stream(packets)
@@ -158,7 +158,7 @@ def handle_send_message_simple_with_history(
persona_id=req.persona_id,
)
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -205,13 +205,14 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
origin=MessageOrigin.API,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
return gather_stream(packets)

View File

@@ -54,6 +54,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
@@ -73,6 +76,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):

View File

@@ -48,7 +48,6 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id
@@ -295,7 +294,7 @@ def list_all_query_history_exports(
)
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
@router.post("/admin/query-history/start-export")
def start_query_history_export(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
@@ -341,7 +340,7 @@ def start_query_history_export(
return {"request_id": task_id}
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
@router.get("/admin/query-history/export-status")
def get_query_history_export_status(
request_id: str,
_: User | None = Depends(current_admin_user),
@@ -375,7 +374,7 @@ def get_query_history_export_status(
return {"status": TaskStatus.SUCCESS}
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
@router.get("/admin/query-history/download")
def download_query_history_csv(
request_id: str,
_: User | None = Depends(current_admin_user),

View File

@@ -1,92 +0,0 @@
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
import requests
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
from onyx.utils.logger import setup_logger
logger = setup_logger()
# In-memory storage for tenant overrides (populated at startup)
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
"""
Fetch tenant-specific usage limit overrides from the control plane.
Returns:
Dictionary mapping tenant_id to their specific limit overrides.
Returns empty dict on any error (falls back to defaults).
"""
try:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides"
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
tenant_overrides = response.json()
# Parse each tenant's overrides
result: dict[str, TenantUsageLimitOverrides] = {}
for override_data in tenant_overrides:
tenant_id = override_data["tenant_id"]
try:
result[tenant_id] = TenantUsageLimitOverrides(**override_data)
except Exception as e:
logger.warning(
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
)
return result
except requests.exceptions.RequestException as e:
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
return {}
except Exception as e:
logger.error(f"Error parsing usage limit overrides: {e}")
return {}
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
"""
Load tenant usage limit overrides from the control plane.
Called at server startup to populate the in-memory cache.
"""
global _tenant_usage_limit_overrides
logger.info("Loading tenant usage limit overrides from control plane...")
overrides = fetch_usage_limit_overrides()
_tenant_usage_limit_overrides = overrides
if overrides:
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
else:
logger.info("No tenant-specific usage limit overrides found")
return overrides
def get_tenant_usage_limit_overrides(
tenant_id: str,
) -> TenantUsageLimitOverrides | None:
"""
Get the usage limit overrides for a specific tenant.
Args:
tenant_id: The tenant ID to look up
Returns:
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
"""
global _tenant_usage_limit_overrides
if _tenant_usage_limit_overrides is None:
_tenant_usage_limit_overrides = load_usage_limit_overrides()
return _tenant_usage_limit_overrides.get(tenant_id)

View File

@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal
import requests
import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,21 +16,15 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(
tenant_id: str,
billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
payload = {
"tenant_id": tenant_id,
"billing_period": billing_period,
}
response = requests.post(url, headers=headers, json=payload)
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
@@ -78,24 +72,22 @@ def fetch_billing_information(
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Update the number of seats for a tenant's subscription.
Preserves the existing price (monthly, annual, or grandfathered).
Send a request to the control service to register the number of users for a tenant.
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
subscription_item = subscription["items"]["data"][0]
# Use existing price to preserve the customer's current plan
current_price_id = subscription_item.price.id
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription_item.id,
"price": current_price_id,
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"quantity": number_of_users,
}
],

View File

@@ -10,7 +10,6 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
@@ -105,18 +104,15 @@ async def create_customer_portal_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create subscription session")
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
@@ -74,12 +73,6 @@ class SubscriptionSessionResponse(BaseModel):
sessionId: str
class CreateSubscriptionSessionRequest(BaseModel):
"""Request to create a subscription checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int

View File

@@ -1,4 +1,5 @@
import asyncio
import logging
import uuid
import aiohttp # Async HTTP client
@@ -9,7 +10,10 @@ from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
@@ -21,18 +25,11 @@ from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -40,25 +37,15 @@ from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
from onyx.llm.well_known_providers.llm_provider_options import (
get_recommendations,
)
from onyx.llm.well_known_providers.llm_provider_options import (
model_configurations_for_provider,
)
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import get_anthropic_model_names
from onyx.llm.llm_provider_options import get_openai_model_names
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
@@ -66,7 +53,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
@@ -275,173 +262,61 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
def _build_model_configuration_upsert_requests(
provider_name: str,
recommendations: LLMRecommendations,
) -> list[ModelConfigurationUpsertRequest]:
model_configurations = model_configurations_for_provider(
provider_name, recommendations
)
return [
ModelConfigurationUpsertRequest(
name=model_configuration.name,
is_visible=model_configuration.is_visible,
max_input_tokens=model_configuration.max_input_tokens,
supports_image_input=model_configuration.supports_image_input,
)
for model_configuration in model_configurations
]
def configure_default_api_keys(db_session: Session) -> None:
"""Configure default LLM providers using recommended-models.json for model selection."""
# Load recommendations from JSON config
recommendations = get_recommendations()
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
nonlocal has_set_default_provider
try:
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
# Configure OpenAI provider
if OPENAI_DEFAULT_API_KEY:
default_model = recommendations.get_default_model(OPENAI_PROVIDER_NAME)
if default_model is None:
logger.error(
f"No default model found for {OPENAI_PROVIDER_NAME} in recommendations"
)
default_model_name = default_model.name if default_model else "gpt-5.2"
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
# Create default image generation config using the OpenAI API key
try:
create_default_image_gen_config_from_api_key(
db_session, OPENAI_DEFAULT_API_KEY
)
except Exception as e:
logger.error(f"Failed to create default image gen config: {e}")
else:
logger.info(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
# Configure Anthropic provider
if ANTHROPIC_DEFAULT_API_KEY:
default_model = recommendations.get_default_model(ANTHROPIC_PROVIDER_NAME)
if default_model is None:
logger.error(
f"No default model found for {ANTHROPIC_PROVIDER_NAME} in recommendations"
)
default_model_name = (
default_model.name if default_model else "claude-sonnet-4-5"
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_configurations=[
ModelConfigurationUpsertRequest(
name=name,
is_visible=False,
max_input_tokens=None,
)
for name in get_anthropic_model_names()
],
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.info(
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
# Configure Vertex AI provider
if VERTEXAI_DEFAULT_CREDENTIALS:
default_model = recommendations.get_default_model(VERTEXAI_PROVIDER_NAME)
if default_model is None:
logger.error(
f"No default model found for {VERTEXAI_PROVIDER_NAME} in recommendations"
)
default_model_name = default_model.name if default_model else "gemini-2.5-pro"
# Vertex AI uses custom_config for credentials and location
custom_config = {
VERTEX_CREDENTIALS_FILE_KWARG: VERTEXAI_DEFAULT_CREDENTIALS,
VERTEX_LOCATION_KWARG: VERTEXAI_DEFAULT_LOCATION,
}
vertexai_provider = LLMProviderUpsertRequest(
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
if OPENAI_DEFAULT_API_KEY:
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,
is_visible=False,
max_input_tokens=None,
)
for model_name in get_openai_model_names()
],
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
try:
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
# Configure OpenRouter provider
if OPENROUTER_DEFAULT_API_KEY:
default_model = recommendations.get_default_model(OPENROUTER_PROVIDER_NAME)
if default_model is None:
logger.error(
f"No default model found for {OPENROUTER_PROVIDER_NAME} in recommendations"
)
default_model_name = default_model.name if default_model else "z-ai/glm-4.7"
# For OpenRouter, we use the visible models from recommendations as model_configurations
# since OpenRouter models are dynamic (fetched from their API)
visible_models = recommendations.get_visible_models(OPENROUTER_PROVIDER_NAME)
model_configurations = [
ModelConfigurationUpsertRequest(
name=model.name,
is_visible=True,
max_input_tokens=None,
display_name=model.display_name,
)
for model in visible_models
]
openrouter_provider = LLMProviderUpsertRequest(
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
)
# Configure Cohere embedding provider
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
@@ -687,11 +562,17 @@ async def assign_tenant_to_user(
try:
add_users_to_tenant([email], tenant_id)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=email,
event=MilestoneRecordType.TENANT_CREATED,
)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")

View File

@@ -249,17 +249,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
)
raise
# Remove from invited users list since they've accepted
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
invited_users = get_invited_users()
if email in invited_users:
invited_users.remove(email)
write_invited_users(invited_users)
logger.info(f"Removed {email} from invited users list after acceptance")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def deny_user_invite(email: str, tenant_id: str) -> None:
"""

View File

@@ -16,9 +16,8 @@ from onyx.db.token_limit import insert_user_token_rate_limit
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
from onyx.server.utils import PUBLIC_API_TAGS
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/admin/token-rate-limits")
"""

View File

@@ -1,38 +0,0 @@
"""EE Usage limits - trial detection via billing information."""
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def is_tenant_on_trial(tenant_id: str) -> bool:
"""
Determine if a tenant is currently on a trial subscription.
In multi-tenant mode, we fetch billing information from the control plane
to determine if the tenant has an active trial.
"""
if not MULTI_TENANT:
return False
try:
billing_info = fetch_billing_information(tenant_id)
# If not subscribed at all, check if we have trial information
if isinstance(billing_info, SubscriptionStatusResponse):
# No subscription means they're likely on trial (new tenant)
return True
if isinstance(billing_info, BillingInformation):
return billing_info.status == "trialing"
return False
except Exception as e:
logger.warning(f"Failed to fetch billing info for trial check: {e}")
# Default to trial limits on error (more restrictive = safer)
return True

View File

@@ -21,12 +21,11 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
router = APIRouter(prefix="/manage")
@router.get("/admin/user-group")

View File

@@ -1,126 +0,0 @@
"""RSA-4096 license signature verification utilities."""
import base64
import json
import os
from datetime import datetime
from datetime import timezone
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from ee.onyx.server.license.models import LicenseData
from ee.onyx.server.license.models import LicensePayload
from onyx.server.settings.models import ApplicationStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
def _get_public_key() -> RSAPublicKey:
"""Load the public key from environment variable."""
if not LICENSE_PUBLIC_KEY_PEM:
raise ValueError(
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
"License verification requires the control plane public key."
)
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
if not isinstance(key, RSAPublicKey):
raise ValueError("Expected RSA public key")
return key
def verify_license_signature(license_data: str) -> LicensePayload:
"""
Verify RSA-4096 signature and return payload if valid.
Args:
license_data: Base64-encoded JSON containing payload and signature
Returns:
LicensePayload if signature is valid
Raises:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))
license_obj = LicenseData(**decoded)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
signature_bytes = base64.b64decode(license_obj.signature)
# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()
public_key.verify(
signature_bytes,
payload_json.encode(),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
return license_obj.payload
except InvalidSignature:
logger.error("License signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
except Exception:
logger.exception("Unexpected error during license verification")
raise ValueError("License verification failed: unexpected error")
def get_license_status(
payload: LicensePayload,
grace_period_end: datetime | None = None,
) -> ApplicationStatus:
"""
Determine current license status based on expiry.
Args:
payload: The verified license payload
grace_period_end: Optional grace period end datetime
Returns:
ApplicationStatus indicating current license state
"""
now = datetime.now(timezone.utc)
# Check if grace period has expired
if grace_period_end and now > grace_period_end:
return ApplicationStatus.GATED_ACCESS
# Check if license has expired
if now > payload.expires_at:
if grace_period_end and now <= grace_period_end:
return ApplicationStatus.GRACE_PERIOD
return ApplicationStatus.GATED_ACCESS
# License is valid
return ApplicationStatus.ACTIVE
def is_license_valid(payload: LicensePayload) -> bool:
"""Check if a license is currently valid (not expired)."""
now = datetime.now(timezone.utc)
return now <= payload.expires_at

View File

@@ -1,4 +1,5 @@
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
class GPUStatus:

View File

@@ -0,0 +1,562 @@
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
global _CONNECTOR_CLASSIFIER_TOKENIZER
from transformers import AutoTokenizer, PreTrainedTokenizer
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
def get_local_connector_classifier(
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
) -> ConnectorClassifier:
global _CONNECTOR_CLASSIFIER_MODEL
if _CONNECTOR_CLASSIFIER_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.warning(f"Failed to load model directly: {e}")
try:
# Attempt to download the model snapshot
logger.info(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
local_path
)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
from transformers import AutoTokenizer, PreTrainedTokenizer
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_INTENT_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str | None = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
logger.notice(f"Loading model from local cache: {model_name_or_path}")
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
logger.notice(f"Loaded model from local cache: {local_path}")
except Exception as e:
logger.warning(f"Failed to load model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load model even after attempted snapshot download: {e}"
)
raise
return _INTENT_MODEL
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> "SetFitModel":
from setfit import SetFitModel
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
)
except Exception as e:
logger.warning(f"Failed to load content information model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load content information model even after attempted snapshot download: {e}"
)
raise
return _INFORMATION_CONTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: "PreTrainedTokenizer",
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
token and then the user query.
"""
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
for connector in connectors:
connector_token_ids = tokenizer(
connector,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
connector_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([connector_token_end_id], dtype=torch.long),
),
dim=-1,
)
query_token_ids = tokenizer(
query,
add_special_tokens=False,
return_tensors="pt",
)
input_ids = torch.cat(
(
input_ids,
query_token_ids["input_ids"].squeeze(dim=0),
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
),
dim=-1,
)
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
def warm_up_connector_classifier_model() -> None:
logger.info(
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
)
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
connector_classifier = get_local_connector_classifier()
input_ids, attention_mask = tokenize_connector_classification_query(
["GitHub"],
"onyx classifier query google doc",
connector_classifier_tokenizer,
connector_classifier.connector_end_token_id,
)
input_ids = input_ids.to(connector_classifier.device)
attention_mask = attention_mask.to(connector_classifier.device)
connector_classifier(input_ids, attention_mask)
def warm_up_intent_model() -> None:
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
intent_tokenizer = get_intent_model_tokenizer()
tokens = intent_tokenizer(
MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
)
intent_model = get_local_intent_model()
device = intent_model.device
intent_model(
query_ids=tokens["input_ids"].to(device),
query_mask=tokens["attention_mask"].to(device),
)
def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
device = intent_model.device
outputs = intent_model(
query_ids=tokens["input_ids"].to(device),
query_mask=tokens["attention_mask"].to(device),
)
token_logits = outputs["token_logits"]
intent_logits = outputs["intent_logits"]
# Move tensors to CPU before applying softmax and converting to numpy
intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
# Extract the probabilities for the positive class (index 1) for each token
token_positive_probs = token_probabilities[:, 1].tolist()
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
"""
Assign a score to the segments in question. The model stored in get_local_information_content_model()
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
In the code outside of the model/inference model servers that score will be converted into the actual
boost factor.
"""
def _prob_to_score(prob: float) -> float:
"""
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
"""
_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75
if prob < _MIN_BASE_SCORE:
raw_score = 0.0
elif prob < _MAX_BASE_SCORE:
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
else:
raw_score = 1.0
return (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
_BATCH_SIZE = 32
content_model = get_local_information_content_model()
# Process inputs in batches
all_output_classes: list[int] = []
all_base_output_probabilities: list[float] = []
for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
prediction_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
content_classification_predictions = [
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
]
return content_classification_predictions
def map_keywords(
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
if not len(tokens) == len(is_keyword):
raise ValueError("Length of tokens and keyword predictions must match")
if input_ids[0] == tokenizer.cls_token_id:
tokens = tokens[1:]
is_keyword = is_keyword[1:]
if input_ids[-1] == tokenizer.sep_token_id:
tokens = tokens[:-1]
is_keyword = is_keyword[:-1]
unk_token = tokenizer.unk_token
if unk_token in tokens:
raise ValueError("Unknown token detected in the input")
keywords = []
current_keyword = ""
for ind, token in enumerate(tokens):
if is_keyword[ind]:
if token.startswith("##"):
current_keyword += token[2:]
else:
if current_keyword:
keywords.append(current_keyword)
current_keyword = token
else:
# If mispredicted a later token of a keyword, add it to the current keyword
# to complete it
if current_keyword:
if len(current_keyword) > 2 and current_keyword.startswith("##"):
current_keyword = current_keyword[2:]
else:
keywords.append(current_keyword)
current_keyword = ""
if current_keyword:
keywords.append(current_keyword)
return keywords
def clean_keywords(keywords: list[str]) -> list[str]:
cleaned_words = []
for word in keywords:
word = word[:-2] if word.endswith("'s") else word
word = word.replace("/", " ")
word = word.replace("'", "").replace('"', "")
cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
return cleaned_words
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
tokenizer = get_connector_classifier_tokenizer()
model = get_local_connector_classifier()
connector_names = req.available_connectors
input_ids, attention_mask = tokenize_connector_classification_query(
connector_names,
req.query,
tokenizer,
model.connector_end_token_id,
)
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
global_confidence, classifier_confidence = model(input_ids, attention_mask)
if global_confidence.item() < 0.5:
return []
passed_connectors = []
for i, connector_name in enumerate(connector_names):
if classifier_confidence.view(-1)[i].item() > 0.5:
passed_connectors.append(connector_name)
return passed_connectors
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
tokenizer = get_intent_model_tokenizer()
model_input = tokenizer(
intent_req.query, return_tensors="pt", truncation=False, padding=False
)
if len(model_input.input_ids[0]) > 512:
# If the user text is too long, assume it is semantic and keep all words
return True, intent_req.query.split()
intent_probs, token_probs = run_inference(model_input)
is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
keyword_preds = [
token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
]
try:
keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
except Exception as e:
logger.warning(
f"Failed to extract keywords for query: {intent_req.query} due to {e}"
)
# Fallback to keeping all words
keywords = intent_req.query.split()
cleaned_keywords = clean_keywords(keywords)
return is_keyword_sequence, cleaned_keywords
@router.post("/connector-classification")
async def process_connector_classification_request(
classification_request: ConnectorClassificationRequest,
) -> ConnectorClassificationResponse:
if INDEXING_ONLY:
raise RuntimeError(
"Indexing model server should not call connector classification endpoint"
)
if len(classification_request.available_connectors) == 0:
return ConnectorClassificationResponse(connectors=[])
connectors = run_connector_classification(classification_request)
return ConnectorClassificationResponse(connectors=connectors)
@router.post("/query-analysis")
async def process_analysis_request(
intent_request: IntentRequest,
) -> IntentResponse:
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
return run_content_classification_inference(content_classification_requests)

View File

@@ -1,6 +1,7 @@
import asyncio
import time
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from fastapi import APIRouter
@@ -9,13 +10,16 @@ from fastapi import Request
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder, SentenceTransformer
logger = setup_logger()
@@ -23,6 +27,11 @@ router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODEL: Optional["CrossEncoder"] = None
# If we are not only indexing, dont want retry very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
def get_embedding_model(
@@ -33,7 +42,7 @@ def get_embedding_model(
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
"""
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer # type: ignore
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
"""
@@ -78,6 +87,19 @@ def get_embedding_model(
return _GLOBAL_MODELS_DICT[model_name]
def get_local_reranking_model(
model_name: str,
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder # type: ignore
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
model = CrossEncoder(model_name)
_RERANK_MODEL = model
return _RERANK_MODEL
ENCODING_RETRIES = 3
ENCODING_RETRY_DELAY = 0.1
@@ -167,6 +189,16 @@ async def embed_text(
return embeddings
@simple_log_function_time()
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
@@ -222,3 +254,39 @@ async def process_embed_request(
raise HTTPException(
status_code=500, detail=f"Error during embedding process: {e}"
)
@router.post("/cross-encoder-scores")
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
# Only local models should use this endpoint - API providers should make direct API calls
if rerank_request.provider_type is not None:
raise ValueError(
f"Model server reranking endpoint should only be used for local models. "
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
)
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not rerank_request.documents or not rerank_request.query:
raise HTTPException(
status_code=400, detail="Missing documents or query for reranking"
)
if not all(rerank_request.documents):
raise ValueError("Empty documents cannot be reranked.")
try:
# At this point, provider_type is None, so handle local reranking
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(
status_code=500, detail="Failed to run Cross-Encoder reranking"
)

View File

@@ -1,5 +0,0 @@
This directory contains code that was useful and may become useful again in the future.
We stopped using rerankers because the state of the art rerankers are not significantly better than the biencoders and much worse than LLMs which are also capable of acting on a small set of documents for filtering, reranking, etc.
We stopped using the internal query classifier as that's now offloaded to the LLM which does query expansion so we know ahead of time if it's a keyword or semantic query.

View File

@@ -1,573 +0,0 @@
# from typing import cast
# from typing import Optional
# from typing import TYPE_CHECKING
# import numpy as np
# import torch
# import torch.nn.functional as F
# from fastapi import APIRouter
# from huggingface_hub import snapshot_download
# from pydantic import BaseModel
# from model_server.constants import MODEL_WARM_UP_STRING
# from model_server.legacy.onyx_torch_model import ConnectorClassifier
# from model_server.legacy.onyx_torch_model import HybridClassifier
# from model_server.utils import simple_log_function_time
# from onyx.utils.logger import setup_logger
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
# from shared_configs.configs import INDEXING_ONLY
# from shared_configs.configs import INTENT_MODEL_TAG
# from shared_configs.configs import INTENT_MODEL_VERSION
# from shared_configs.model_server_models import IntentRequest
# from shared_configs.model_server_models import IntentResponse
# if TYPE_CHECKING:
# from setfit import SetFitModel # type: ignore[import-untyped]
# from transformers import PreTrainedTokenizer, BatchEncoding
# INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi" * 50
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = 1.0
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = 0.7
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = 4.0
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = 10
# INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
# INFORMATION_CONTENT_MODEL_TAG: str | None = None
# class ConnectorClassificationRequest(BaseModel):
# available_connectors: list[str]
# query: str
# class ConnectorClassificationResponse(BaseModel):
# connectors: list[str]
# class ContentClassificationPrediction(BaseModel):
# predicted_label: int
# content_boost_factor: float
# logger = setup_logger()
# router = APIRouter(prefix="/custom")
# _CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
# _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
# _INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
# _INTENT_MODEL: HybridClassifier | None = None
# _INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
# def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
# global _CONNECTOR_CLASSIFIER_TOKENIZER
# from transformers import AutoTokenizer, PreTrainedTokenizer
# if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# # The tokenizer details are not uploaded to the HF hub since it's just the
# # unmodified distilbert tokenizer.
# _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
# PreTrainedTokenizer,
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
# )
# return _CONNECTOR_CLASSIFIER_TOKENIZER
# def get_local_connector_classifier(
# model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
# tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
# ) -> ConnectorClassifier:
# global _CONNECTOR_CLASSIFIER_MODEL
# if _CONNECTOR_CLASSIFIER_MODEL is None:
# try:
# # Calculate where the cache should be, then load from local if available
# local_path = snapshot_download(
# repo_id=model_name_or_path, revision=tag, local_files_only=True
# )
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
# local_path
# )
# except Exception as e:
# logger.warning(f"Failed to load model directly: {e}")
# try:
# # Attempt to download the model snapshot
# logger.info(f"Downloading model snapshot for {model_name_or_path}")
# local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
# local_path
# )
# except Exception as e:
# logger.error(
# f"Failed to load model even after attempted snapshot download: {e}"
# )
# raise
# return _CONNECTOR_CLASSIFIER_MODEL
# def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
# from transformers import AutoTokenizer, PreTrainedTokenizer
# global _INTENT_TOKENIZER
# if _INTENT_TOKENIZER is None:
# # The tokenizer details are not uploaded to the HF hub since it's just the
# # unmodified distilbert tokenizer.
# _INTENT_TOKENIZER = cast(
# PreTrainedTokenizer,
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
# )
# return _INTENT_TOKENIZER
# def get_local_intent_model(
# model_name_or_path: str = INTENT_MODEL_VERSION,
# tag: str | None = INTENT_MODEL_TAG,
# ) -> HybridClassifier:
# global _INTENT_MODEL
# if _INTENT_MODEL is None:
# try:
# # Calculate where the cache should be, then load from local if available
# logger.notice(f"Loading model from local cache: {model_name_or_path}")
# local_path = snapshot_download(
# repo_id=model_name_or_path, revision=tag, local_files_only=True
# )
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
# logger.notice(f"Loaded model from local cache: {local_path}")
# except Exception as e:
# logger.warning(f"Failed to load model directly: {e}")
# try:
# # Attempt to download the model snapshot
# logger.notice(f"Downloading model snapshot for {model_name_or_path}")
# local_path = snapshot_download(
# repo_id=model_name_or_path, revision=tag, local_files_only=False
# )
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
# except Exception as e:
# logger.error(
# f"Failed to load model even after attempted snapshot download: {e}"
# )
# raise
# return _INTENT_MODEL
# def get_local_information_content_model(
# model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
# tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
# ) -> "SetFitModel":
# from setfit import SetFitModel
# global _INFORMATION_CONTENT_MODEL
# if _INFORMATION_CONTENT_MODEL is None:
# try:
# # Calculate where the cache should be, then load from local if available
# logger.notice(
# f"Loading content information model from local cache: {model_name_or_path}"
# )
# local_path = snapshot_download(
# repo_id=model_name_or_path, revision=tag, local_files_only=True
# )
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
# logger.notice(
# f"Loaded content information model from local cache: {local_path}"
# )
# except Exception as e:
# logger.warning(f"Failed to load content information model directly: {e}")
# try:
# # Attempt to download the model snapshot
# logger.notice(
# f"Downloading content information model snapshot for {model_name_or_path}"
# )
# local_path = snapshot_download(
# repo_id=model_name_or_path, revision=tag, local_files_only=False
# )
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
# except Exception as e:
# logger.error(
# f"Failed to load content information model even after attempted snapshot download: {e}"
# )
# raise
# return _INFORMATION_CONTENT_MODEL
# def tokenize_connector_classification_query(
# connectors: list[str],
# query: str,
# tokenizer: "PreTrainedTokenizer",
# connector_token_end_id: int,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
# The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
# token and then the user query.
# """
# input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
# for connector in connectors:
# connector_token_ids = tokenizer(
# connector,
# add_special_tokens=False,
# return_tensors="pt",
# )
# input_ids = torch.cat(
# (
# input_ids,
# connector_token_ids["input_ids"].squeeze(dim=0),
# torch.tensor([connector_token_end_id], dtype=torch.long),
# ),
# dim=-1,
# )
# query_token_ids = tokenizer(
# query,
# add_special_tokens=False,
# return_tensors="pt",
# )
# input_ids = torch.cat(
# (
# input_ids,
# query_token_ids["input_ids"].squeeze(dim=0),
# torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
# ),
# dim=-1,
# )
# attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
# return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
# def warm_up_connector_classifier_model() -> None:
# logger.info(
# f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
# )
# connector_classifier_tokenizer = get_connector_classifier_tokenizer()
# connector_classifier = get_local_connector_classifier()
# input_ids, attention_mask = tokenize_connector_classification_query(
# ["GitHub"],
# "onyx classifier query google doc",
# connector_classifier_tokenizer,
# connector_classifier.connector_end_token_id,
# )
# input_ids = input_ids.to(connector_classifier.device)
# attention_mask = attention_mask.to(connector_classifier.device)
# connector_classifier(input_ids, attention_mask)
# def warm_up_intent_model() -> None:
# logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
# intent_tokenizer = get_intent_model_tokenizer()
# tokens = intent_tokenizer(
# MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
# )
# intent_model = get_local_intent_model()
# device = intent_model.device
# intent_model(
# query_ids=tokens["input_ids"].to(device),
# query_mask=tokens["attention_mask"].to(device),
# )
# def warm_up_information_content_model() -> None:
# logger.notice("Warming up Content Model") # TODO: add version if needed
# information_content_model = get_local_information_content_model()
# information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
# @simple_log_function_time()
# def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
# intent_model = get_local_intent_model()
# device = intent_model.device
# outputs = intent_model(
# query_ids=tokens["input_ids"].to(device),
# query_mask=tokens["attention_mask"].to(device),
# )
# token_logits = outputs["token_logits"]
# intent_logits = outputs["intent_logits"]
# # Move tensors to CPU before applying softmax and converting to numpy
# intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
# token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
# # Extract the probabilities for the positive class (index 1) for each token
# token_positive_probs = token_probabilities[:, 1].tolist()
# return intent_probabilities.tolist(), token_positive_probs
# @simple_log_function_time()
# def run_content_classification_inference(
# text_inputs: list[str],
# ) -> list[ContentClassificationPrediction]:
# """
# Assign a score to the segments in question. The model stored in get_local_information_content_model()
# creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
# In the code outside of the model/inference model servers that score will be converted into the actual
# boost factor.
# """
# def _prob_to_score(prob: float) -> float:
# """
# Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
# """
# _MIN_BASE_SCORE = 0.25
# _MAX_BASE_SCORE = 0.75
# if prob < _MIN_BASE_SCORE:
# raw_score = 0.0
# elif prob < _MAX_BASE_SCORE:
# raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
# else:
# raw_score = 1.0
# return (
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
# + (
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
# - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
# )
# * raw_score
# )
# _BATCH_SIZE = 32
# content_model = get_local_information_content_model()
# # Process inputs in batches
# all_output_classes: list[int] = []
# all_base_output_probabilities: list[float] = []
# for i in range(0, len(text_inputs), _BATCH_SIZE):
# batch = text_inputs[i : i + _BATCH_SIZE]
# batch_with_prefix = []
# batch_indices = []
# # Pre-allocate results for this batch
# batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
# batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# # Pre-process batch to handle long input exceptions
# for j, text in enumerate(batch):
# if len(text) == 0:
# # if no input, treat as non-informative from the model's perspective
# batch_output_classes[j] = np.array(0)
# batch_probabilities[j] = np.array(0.0)
# logger.warning("Input for Content Information Model is empty")
# elif (
# len(text.split())
# <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
# ):
# # if input is short, use the model
# batch_with_prefix.append(
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
# )
# batch_indices.append(j)
# else:
# # if longer than cutoff, treat as informative (stay with default), but issue warning
# logger.warning("Input for Content Information Model too long")
# if batch_with_prefix: # Only run model if we have valid inputs
# # Get predictions for the batch
# model_output_classes = content_model(batch_with_prefix)
# model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# # Place results in the correct positions
# for idx, batch_idx in enumerate(batch_indices):
# batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
# batch_probabilities[batch_idx] = model_output_probabilities[idx][
# 1
# ].numpy() # x[1] is prob of the positive class
# all_output_classes.extend([int(x) for x in batch_output_classes])
# all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
# logits = [
# np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
# for p in all_base_output_probabilities
# ]
# scaled_logits = [
# logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
# for logit in logits
# ]
# output_probabilities_with_temp = [
# np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
# for scaled_logit in scaled_logits
# ]
# prediction_scores = [
# _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
# ]
# content_classification_predictions = [
# ContentClassificationPrediction(
# predicted_label=predicted_label, content_boost_factor=output_score
# )
# for predicted_label, output_score in zip(all_output_classes, prediction_scores)
# ]
# return content_classification_predictions
# def map_keywords(
# input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
# ) -> list[str]:
# tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
# if not len(tokens) == len(is_keyword):
# raise ValueError("Length of tokens and keyword predictions must match")
# if input_ids[0] == tokenizer.cls_token_id:
# tokens = tokens[1:]
# is_keyword = is_keyword[1:]
# if input_ids[-1] == tokenizer.sep_token_id:
# tokens = tokens[:-1]
# is_keyword = is_keyword[:-1]
# unk_token = tokenizer.unk_token
# if unk_token in tokens:
# raise ValueError("Unknown token detected in the input")
# keywords = []
# current_keyword = ""
# for ind, token in enumerate(tokens):
# if is_keyword[ind]:
# if token.startswith("##"):
# current_keyword += token[2:]
# else:
# if current_keyword:
# keywords.append(current_keyword)
# current_keyword = token
# else:
# # If mispredicted a later token of a keyword, add it to the current keyword
# # to complete it
# if current_keyword:
# if len(current_keyword) > 2 and current_keyword.startswith("##"):
# current_keyword = current_keyword[2:]
# else:
# keywords.append(current_keyword)
# current_keyword = ""
# if current_keyword:
# keywords.append(current_keyword)
# return keywords
# def clean_keywords(keywords: list[str]) -> list[str]:
# cleaned_words = []
# for word in keywords:
# word = word[:-2] if word.endswith("'s") else word
# word = word.replace("/", " ")
# word = word.replace("'", "").replace('"', "")
# cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
# return cleaned_words
# def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
# tokenizer = get_connector_classifier_tokenizer()
# model = get_local_connector_classifier()
# connector_names = req.available_connectors
# input_ids, attention_mask = tokenize_connector_classification_query(
# connector_names,
# req.query,
# tokenizer,
# model.connector_end_token_id,
# )
# input_ids = input_ids.to(model.device)
# attention_mask = attention_mask.to(model.device)
# global_confidence, classifier_confidence = model(input_ids, attention_mask)
# if global_confidence.item() < 0.5:
# return []
# passed_connectors = []
# for i, connector_name in enumerate(connector_names):
# if classifier_confidence.view(-1)[i].item() > 0.5:
# passed_connectors.append(connector_name)
# return passed_connectors
# def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
# tokenizer = get_intent_model_tokenizer()
# model_input = tokenizer(
# intent_req.query, return_tensors="pt", truncation=False, padding=False
# )
# if len(model_input.input_ids[0]) > 512:
# # If the user text is too long, assume it is semantic and keep all words
# return True, intent_req.query.split()
# intent_probs, token_probs = run_inference(model_input)
# is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
# keyword_preds = [
# token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
# ]
# try:
# keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
# except Exception as e:
# logger.warning(
# f"Failed to extract keywords for query: {intent_req.query} due to {e}"
# )
# # Fallback to keeping all words
# keywords = intent_req.query.split()
# cleaned_keywords = clean_keywords(keywords)
# return is_keyword_sequence, cleaned_keywords
# @router.post("/connector-classification")
# async def process_connector_classification_request(
# classification_request: ConnectorClassificationRequest,
# ) -> ConnectorClassificationResponse:
# if INDEXING_ONLY:
# raise RuntimeError(
# "Indexing model server should not call connector classification endpoint"
# )
# if len(classification_request.available_connectors) == 0:
# return ConnectorClassificationResponse(connectors=[])
# connectors = run_connector_classification(classification_request)
# return ConnectorClassificationResponse(connectors=connectors)
# @router.post("/query-analysis")
# async def process_analysis_request(
# intent_request: IntentRequest,
# ) -> IntentResponse:
# if INDEXING_ONLY:
# raise RuntimeError("Indexing model server should not call intent endpoint")
# is_keyword, keywords = run_analysis(intent_request)
# return IntentResponse(is_keyword=is_keyword, keywords=keywords)
# @router.post("/content-classification")
# async def process_content_classification_request(
# content_classification_requests: list[str],
# ) -> list[ContentClassificationPrediction]:
# return run_content_classification_inference(content_classification_requests)

View File

@@ -1,154 +0,0 @@
# import json
# import os
# from typing import cast
# from typing import TYPE_CHECKING
# import torch
# import torch.nn as nn
# if TYPE_CHECKING:
# from transformers import DistilBertConfig
# class HybridClassifier(nn.Module):
# def __init__(self) -> None:
# from transformers import DistilBertConfig, DistilBertModel
# super().__init__()
# config = DistilBertConfig()
# self.distilbert = DistilBertModel(config)
# config = self.distilbert.config # type: ignore
# # Keyword tokenwise binary classification layer
# self.keyword_classifier = nn.Linear(config.dim, 2)
# # Intent Classifier layers
# self.pre_classifier = nn.Linear(config.dim, config.dim)
# self.intent_classifier = nn.Linear(config.dim, 2)
# self.device = torch.device("cpu")
# def forward(
# self,
# query_ids: torch.Tensor,
# query_mask: torch.Tensor,
# ) -> dict[str, torch.Tensor]:
# outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
# sequence_output = outputs.last_hidden_state
# # Intent classification on the CLS token
# cls_token_state = sequence_output[:, 0, :]
# pre_classifier_out = self.pre_classifier(cls_token_state)
# intent_logits = self.intent_classifier(pre_classifier_out)
# # Keyword classification on all tokens
# token_logits = self.keyword_classifier(sequence_output)
# return {"intent_logits": intent_logits, "token_logits": token_logits}
# @classmethod
# def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
# model_path = os.path.join(load_directory, "pytorch_model.bin")
# config_path = os.path.join(load_directory, "config.json")
# with open(config_path, "r") as f:
# config = json.load(f)
# model = cls(**config)
# if torch.backends.mps.is_available():
# # Apple silicon GPU
# device = torch.device("mps")
# elif torch.cuda.is_available():
# device = torch.device("cuda")
# else:
# device = torch.device("cpu")
# model.load_state_dict(torch.load(model_path, map_location=device))
# model = model.to(device)
# model.device = device
# model.eval()
# # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
# for param in model.parameters():
# param.requires_grad = False
# return model
# class ConnectorClassifier(nn.Module):
# def __init__(self, config: "DistilBertConfig") -> None:
# from transformers import DistilBertTokenizer, DistilBertModel
# super().__init__()
# self.config = config
# self.distilbert = DistilBertModel(config)
# config = self.distilbert.config # type: ignore
# self.connector_global_classifier = nn.Linear(config.dim, 1)
# self.connector_match_classifier = nn.Linear(config.dim, 1)
# self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# # Token indicating end of connector name, and on which classifier is used
# self.connector_end_token_id = self.tokenizer.get_vocab()[
# self.config.connector_end_token
# ]
# self.device = torch.device("cpu")
# def forward(
# self,
# input_ids: torch.Tensor,
# attention_mask: torch.Tensor,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# hidden_states = self.distilbert(
# input_ids=input_ids, attention_mask=attention_mask
# ).last_hidden_state
# cls_hidden_states = hidden_states[
# :, 0, :
# ] # Take leap of faith that first token is always [CLS]
# global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
# global_confidence = torch.sigmoid(global_logits).view(-1)
# connector_end_position_ids = input_ids == self.connector_end_token_id
# connector_end_hidden_states = hidden_states[connector_end_position_ids]
# classifier_output = self.connector_match_classifier(connector_end_hidden_states)
# classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
# return global_confidence, classifier_confidence
# @classmethod
# def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
# from transformers import DistilBertConfig
# config = cast(
# DistilBertConfig,
# DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
# )
# device = (
# torch.device("cuda")
# if torch.cuda.is_available()
# else (
# torch.device("mps")
# if torch.backends.mps.is_available()
# else torch.device("cpu")
# )
# )
# state_dict = torch.load(
# os.path.join(repo_dir, "pytorch_model.pt"),
# map_location=device,
# weights_only=True,
# )
# model = cls(config)
# model.load_state_dict(state_dict)
# model.to(device)
# model.device = device
# model.eval()
# for param in model.parameters():
# param.requires_grad = False
# return model

View File

@@ -1,80 +0,0 @@
# import asyncio
# from typing import Optional
# from typing import TYPE_CHECKING
# from fastapi import APIRouter
# from fastapi import HTTPException
# from model_server.utils import simple_log_function_time
# from onyx.utils.logger import setup_logger
# from shared_configs.configs import INDEXING_ONLY
# from shared_configs.model_server_models import RerankRequest
# from shared_configs.model_server_models import RerankResponse
# if TYPE_CHECKING:
# from sentence_transformers import CrossEncoder
# logger = setup_logger()
# router = APIRouter(prefix="/encoder")
# _RERANK_MODEL: Optional["CrossEncoder"] = None
# def get_local_reranking_model(
# model_name: str,
# ) -> "CrossEncoder":
# global _RERANK_MODEL
# from sentence_transformers import CrossEncoder
# if _RERANK_MODEL is None:
# logger.notice(f"Loading {model_name}")
# model = CrossEncoder(model_name)
# _RERANK_MODEL = model
# return _RERANK_MODEL
# @simple_log_function_time()
# async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
# cross_encoder = get_local_reranking_model(model_name)
# # Run CPU-bound reranking in a thread pool
# return await asyncio.get_event_loop().run_in_executor(
# None,
# lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
# )
# @router.post("/cross-encoder-scores")
# async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
# """Cross encoders can be purely black box from the app perspective"""
# # Only local models should use this endpoint - API providers should make direct API calls
# if rerank_request.provider_type is not None:
# raise ValueError(
# f"Model server reranking endpoint should only be used for local models. "
# f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
# )
# if INDEXING_ONLY:
# raise RuntimeError("Indexing model server should not call reranking endpoint")
# if not rerank_request.documents or not rerank_request.query:
# raise HTTPException(
# status_code=400, detail="Missing documents or query for reranking"
# )
# if not all(rerank_request.documents):
# raise ValueError("Empty documents cannot be reranked.")
# try:
# # At this point, provider_type is None, so handle local reranking
# sim_scores = await local_rerank(
# query=rerank_request.query,
# docs=rerank_request.documents,
# model_name=rerank_request.model_name,
# )
# return RerankResponse(scores=sim_scores)
# except Exception as e:
# logger.exception(f"Error during reranking process:\n{str(e)}")
# raise HTTPException(
# status_code=500, detail="Failed to run Cross-Encoder reranking"
# )

View File

@@ -12,8 +12,11 @@ from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
@@ -27,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SKIP_WARM_UP
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@@ -88,6 +92,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not SKIP_WARM_UP:
if not INDEXING_ONLY:
logger.notice("Warming up intent model for inference model server")
warm_up_intent_model()
else:
logger.notice(
"Warming up content information model for indexing model server"
)
warm_up_information_content_model()
else:
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
yield
@@ -107,6 +123,7 @@ def get_model_app() -> FastAPI:
application.include_router(management_router)
application.include_router(encoders_router)
application.include_router(custom_models_router)
request_id_prefix = "INF"
if INDEXING_ONLY:

View File

@@ -0,0 +1,154 @@
import json
import os
from typing import cast
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
if TYPE_CHECKING:
from transformers import DistilBertConfig # type: ignore
class HybridClassifier(nn.Module):
def __init__(self) -> None:
from transformers import DistilBertConfig, DistilBertModel
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
# Keyword tokenwise binary classification layer
self.keyword_classifier = nn.Linear(config.dim, 2)
# Intent Classifier layers
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.intent_classifier = nn.Linear(config.dim, 2)
self.device = torch.device("cpu")
def forward(
self,
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
cls_token_state = sequence_output[:, 0, :]
pre_classifier_out = self.pre_classifier(cls_token_state)
intent_logits = self.intent_classifier(pre_classifier_out)
# Keyword classification on all tokens
token_logits = self.keyword_classifier(sequence_output)
return {"intent_logits": intent_logits, "token_logits": token_logits}
@classmethod
def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
model_path = os.path.join(load_directory, "pytorch_model.bin")
config_path = os.path.join(load_directory, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
model = cls(**config)
if torch.backends.mps.is_available():
# Apple silicon GPU
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.device = device
model.eval()
# Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
for param in model.parameters():
param.requires_grad = False
return model
class ConnectorClassifier(nn.Module):
def __init__(self, config: "DistilBertConfig") -> None:
from transformers import DistilBertTokenizer, DistilBertModel
super().__init__()
self.config = config
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
self.connector_global_classifier = nn.Linear(config.dim, 1)
self.connector_match_classifier = nn.Linear(config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token indicating end of connector name, and on which classifier is used
self.connector_end_token_id = self.tokenizer.get_vocab()[
self.config.connector_end_token
]
self.device = torch.device("cpu")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert( # type: ignore
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
cls_hidden_states = hidden_states[
:, 0, :
] # Take leap of faith that first token is always [CLS]
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
global_confidence = torch.sigmoid(global_logits).view(-1)
connector_end_position_ids = input_ids == self.connector_end_token_id
connector_end_hidden_states = hidden_states[connector_end_position_ids]
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
return global_confidence, classifier_confidence
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
from transformers import DistilBertConfig
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
)
device = (
torch.device("cuda")
if torch.cuda.is_available()
else (
torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
)
state_dict = torch.load(
os.path.join(repo_dir, "pytorch_model.pt"),
map_location=device,
weights_only=True,
)
model = cls(config)
model.load_state_dict(state_dict)
model.to(device)
model.device = device
model.eval()
for param in model.parameters():
param.requires_grad = False
return model

View File

@@ -43,7 +43,7 @@ def get_access_for_document(
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
@@ -93,7 +93,9 @@ def get_access_for_documents(
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(document_ids, db_session)
return versioned_get_access_for_documents_fn(
document_ids, db_session
) # type: ignore
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
@@ -111,7 +113,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
versioned_acl_for_user_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session)
return versioned_acl_for_user_fn(user, db_session) # type: ignore
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:

View File

@@ -105,8 +105,6 @@ class DocExternalAccess:
)
# TODO(andrei): First refactor this into a pydantic model, then get rid of
# duplicate fields.
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin

View File

@@ -1,107 +0,0 @@
"""Captcha verification for user registration."""
import httpx
from pydantic import BaseModel
from pydantic import Field
from onyx.configs.app_configs import CAPTCHA_ENABLED
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
from onyx.utils.logger import setup_logger
logger = setup_logger()
RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"
class CaptchaVerificationError(Exception):
"""Raised when captcha verification fails."""
class RecaptchaResponse(BaseModel):
"""Response from Google reCAPTCHA verification API."""
success: bool
score: float | None = None # Only present for reCAPTCHA v3
action: str | None = None
challenge_ts: str | None = None
hostname: str | None = None
error_codes: list[str] | None = Field(default=None, alias="error-codes")
def is_captcha_enabled() -> bool:
"""Check if captcha verification is enabled."""
return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)
async def verify_captcha_token(
token: str,
expected_action: str = "signup",
) -> None:
"""
Verify a reCAPTCHA token with Google's API.
Args:
token: The reCAPTCHA response token from the client
expected_action: Expected action name for v3 verification
Raises:
CaptchaVerificationError: If verification fails
"""
if not is_captcha_enabled():
return
if not token:
raise CaptchaVerificationError("Captcha token is required")
try:
async with httpx.AsyncClient() as client:
response = await client.post(
RECAPTCHA_VERIFY_URL,
data={
"secret": RECAPTCHA_SECRET_KEY,
"response": token,
},
timeout=10.0,
)
response.raise_for_status()
data = response.json()
result = RecaptchaResponse(**data)
if not result.success:
error_codes = result.error_codes or ["unknown-error"]
logger.warning(f"Captcha verification failed: {error_codes}")
raise CaptchaVerificationError(
f"Captcha verification failed: {', '.join(error_codes)}"
)
# For reCAPTCHA v3, also check the score
if result.score is not None:
if result.score < RECAPTCHA_SCORE_THRESHOLD:
logger.warning(
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
)
raise CaptchaVerificationError(
"Captcha verification failed: suspicious activity detected"
)
# Optionally verify the action matches
if result.action and result.action != expected_action:
logger.warning(
f"Captcha action mismatch: {result.action} != {expected_action}"
)
raise CaptchaVerificationError(
"Captcha verification failed: action mismatch"
)
logger.debug(
f"Captcha verification passed: score={result.score}, "
f"action={result.action}"
)
except httpx.HTTPError as e:
logger.error(f"Captcha API request failed: {e}")
# In case of API errors, we might want to allow registration
# to prevent blocking legitimate users. This is a policy decision.
raise CaptchaVerificationError("Captcha verification service unavailable")

View File

@@ -1,192 +0,0 @@
"""
Utility to validate and block disposable/temporary email addresses.
This module fetches a list of known disposable email domains from a remote source
and caches them for performance. It's used during user registration to prevent
abuse from temporary email services.
"""
import threading
import time
from typing import Set
import httpx
from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
from onyx.utils.logger import setup_logger
logger = setup_logger()
class DisposableEmailValidator:
"""
Thread-safe singleton validator for disposable email domains.
Fetches and caches the list of disposable domains, with periodic refresh.
"""
_instance: "DisposableEmailValidator | None" = None
_lock = threading.Lock()
def __new__(cls) -> "DisposableEmailValidator":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
# Check if already initialized using a try/except to avoid type issues
try:
if self._initialized:
return
except AttributeError:
pass
self._domains: Set[str] = set()
self._last_fetch_time: float = 0
self._fetch_lock = threading.Lock()
# Cache for 1 hour
self._cache_duration = 3600
# Hardcoded fallback list of common disposable domains
# This ensures we block at least these even if the remote fetch fails
self._fallback_domains = {
"trashlify.com",
"10minutemail.com",
"guerrillamail.com",
"mailinator.com",
"tempmail.com",
"throwaway.email",
"yopmail.com",
"temp-mail.org",
"getnada.com",
"maildrop.cc",
}
# Set initialized flag last to prevent race conditions
self._initialized: bool = True
def _should_refresh(self) -> bool:
"""Check if the cached domains should be refreshed."""
return (time.time() - self._last_fetch_time) > self._cache_duration
def _fetch_domains(self) -> Set[str]:
"""
Fetch disposable email domains from the configured URL.
Returns:
Set of domain strings (lowercased)
"""
if not DISPOSABLE_EMAIL_DOMAINS_URL:
logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
return self._fallback_domains.copy()
try:
logger.info(
f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
)
with httpx.Client(timeout=10.0) as client:
response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
response.raise_for_status()
domains_list = response.json()
if not isinstance(domains_list, list):
logger.error(
f"Expected list from disposable domains URL, got {type(domains_list)}"
)
return self._fallback_domains.copy()
# Convert all to lowercase and create set
domains = {domain.lower().strip() for domain in domains_list if domain}
# Always include fallback domains
domains.update(self._fallback_domains)
logger.info(
f"Successfully fetched {len(domains)} disposable email domains"
)
return domains
except httpx.HTTPError as e:
logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
except Exception as e:
logger.warning(f"Failed to fetch disposable domains: {e}")
# On error, return fallback domains
return self._fallback_domains.copy()
def get_domains(self) -> Set[str]:
"""
Get the cached set of disposable email domains.
Refreshes the cache if needed.
Returns:
Set of disposable domain strings (lowercased)
"""
# Fast path: return cached domains if still fresh
if self._domains and not self._should_refresh():
return self._domains.copy()
# Slow path: need to refresh
with self._fetch_lock:
# Double-check after acquiring lock
if self._domains and not self._should_refresh():
return self._domains.copy()
self._domains = self._fetch_domains()
self._last_fetch_time = time.time()
return self._domains.copy()
def is_disposable(self, email: str) -> bool:
"""
Check if an email address uses a disposable domain.
Args:
email: The email address to check
Returns:
True if the email domain is disposable, False otherwise
"""
if not email or "@" not in email:
return False
parts = email.split("@")
if len(parts) != 2 or not parts[0]: # Must have user@domain with non-empty user
return False
domain = parts[1].lower().strip()
if not domain: # Domain part must not be empty
return False
disposable_domains = self.get_domains()
return domain in disposable_domains
# Global singleton instance
_validator = DisposableEmailValidator()
def is_disposable_email(email: str) -> bool:
"""
Check if an email address uses a disposable/temporary domain.
This is a convenience function that uses the global validator instance.
Args:
email: The email address to check
Returns:
True if the email uses a disposable domain, False otherwise
"""
return _validator.is_disposable(email)
def refresh_disposable_domains() -> None:
"""
Force a refresh of the disposable domains list.
This can be called manually if you want to update the list
without waiting for the cache to expire.
"""
_validator._last_fetch_time = 0
_validator.get_domains()

View File

@@ -40,8 +40,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
captcha_token: str | None = None
class UserUpdateWithRole(schemas.BaseUserUpdate):

View File

@@ -60,7 +60,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
@@ -118,7 +117,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.timing import log_function_time
@@ -249,23 +248,13 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
def verify_email_domain(email: str) -> None:
if email.count("@") != 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
)
# Check domain whitelist if configured
if VALID_EMAIL_DOMAINS:
if email.count("@") != 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
if domain not in VALID_EMAIL_DOMAINS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -303,57 +292,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
# Verify captcha if enabled (for cloud signup protection)
from onyx.auth.captcha import CaptchaVerificationError
from onyx.auth.captcha import is_captcha_enabled
from onyx.auth.captcha import verify_captcha_token
if is_captcha_enabled() and request is not None:
# Get captcha token from request body or headers
captcha_token = None
if hasattr(user_create, "captcha_token"):
captcha_token = getattr(user_create, "captcha_token", None)
# Also check headers as a fallback
if not captcha_token:
captcha_token = request.headers.get("X-Captcha-Token")
try:
await verify_captcha_token(
captcha_token or "", expected_action="signup"
)
except CaptchaVerificationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": str(e)},
)
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
)
# Check for disposable emails BEFORE provisioning tenant
# This prevents creating tenants for throwaway email addresses
try:
verify_email_domain(user_create.email)
except HTTPException as e:
# Log blocked disposable email attempts
if (
e.status_code == status.HTTP_400_BAD_REQUEST
and "Disposable email" in str(e.detail)
):
domain = (
user_create.email.split("@")[-1]
if "@" in user_create.email
else "unknown"
)
logger.warning(
f"Blocked disposable email registration attempt: {domain}",
extra={"email_domain": domain},
)
raise
user_count: int | None = None
referral_source = (
request.cookies.get("referral_source", None)
@@ -375,17 +318,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
async with get_async_session_context_manager(tenant_id) as db_session:
# Check invite list based on deployment mode
if MULTI_TENANT:
# Multi-tenant: Only require invite for existing tenants
# New tenant creation (first user) doesn't require an invite
user_count = await get_user_count()
if user_count > 0:
# Tenant already has users - require invite for new users
verify_email_is_invited(user_create.email)
else:
# Single-tenant: Check invite list (skips if SAML/OIDC or no list configured)
verify_email_is_invited(user_create.email)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
@@ -404,7 +338,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request)
user = await super().create(
user_create, safe=safe, request=request
) # type: ignore
user_created = True
except IntegrityError as error:
# Race condition: another request created the same user after the
@@ -668,7 +604,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
remove_user_from_invited_users(user.email)
@@ -714,11 +653,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
event=MilestoneRecordType.USER_SIGNED_UP,
)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
event_type = (
MilestoneRecordType.USER_SIGNED_UP
if user_count == 1
else MilestoneRecordType.MULTIPLE_USERS
)
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=event_type,
properties=None,
db_session=db_session,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1239,7 +1186,7 @@ async def _sync_jwt_oidc_expiry(
return
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
user.oidc_expiry = oidc_expiry
user.oidc_expiry = oidc_expiry # type: ignore
return
if user.oidc_expiry is not None:

View File

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
@@ -516,9 +515,6 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if ENABLE_OPENSEARCH_FOR_ONYX:
return
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)

View File

@@ -124,7 +124,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",

View File

@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
# Ensure the user files indexing worker registers the doc_id migration task
# TODO(subash): remove this once the doc_id migration is complete
"onyx.background.celery.tasks.user_file_processing",
]
)

Some files were not shown because too many files have changed in this diff Show More