Compare commits

2 Commits

Author      SHA1        Message   Date
pablonyx    bbf5fa13dc  update    2025-04-11 15:13:02 -07:00
pablonyx    a4a399bc31  update    2025-04-11 15:11:35 -07:00
541 changed files with 7273 additions and 18187 deletions

View File

@@ -25,10 +25,6 @@ inputs:
tags:
description: 'Image tags'
required: true
no-cache:
description: 'Read from cache'
required: false
default: 'false'
cache-from:
description: 'Cache sources'
required: false
@@ -59,7 +55,6 @@ runs:
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
@@ -82,7 +77,6 @@ runs:
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}
@@ -105,7 +99,6 @@ runs:
push: ${{ inputs.push }}
load: ${{ inputs.load }}
tags: ${{ inputs.tags }}
no-cache: ${{ inputs.no-cache }}
cache-from: ${{ inputs.cache-from }}
cache-to: ${{ inputs.cache-to }}

View File

@@ -7,47 +7,18 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
- run-id=${{ github.run_id }}
- tag=platform-${{ matrix.platform }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout code
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -63,80 +34,18 @@ jobs:
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
id: build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: ${{ matrix.platform }}
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: backend-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
runs-on: ubuntu-latest
needs:
- build-and-push
steps:
# Needed for trivyignore
- name: Checkout
uses: actions/checkout@v4
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: backend-digests-*-${{ github.run_id }}
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
@@ -147,8 +56,6 @@ jobs:
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
with:
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- "*cloud*"
- "*"
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
DEPLOYMENT: cloud
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -38,10 +38,9 @@ jobs:
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -71,12 +70,10 @@ jobs:
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/cloudweb-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/cloudweb-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# no-cache needed due to weird interactions with the builds for different platforms
# NOTE(rkuo): this may not be true any more with the proper cache prefixing by architecture - currently testing with it off
- name: Export digest
run: |
@@ -87,7 +84,7 @@ jobs:
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: cloudweb-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
@@ -101,7 +98,7 @@ jobs:
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: cloudweb-digests-*-${{ github.run_id }}
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
@@ -112,10 +109,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
- name: Login to Docker Hub
uses: docker/login-action@v3
@@ -143,8 +136,6 @@ jobs:
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -7,13 +7,10 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
DOCKER_BUILDKIT: 1
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
jobs:
# Bypassing this for now as the idea of not building is glitching
@@ -54,8 +51,6 @@ jobs:
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
env:
PLATFORM_PAIR: linux-amd64
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -91,17 +86,12 @@ jobs:
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# no-cache: true
build-arm64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
env:
PLATFORM_PAIR: linux-arm64
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -137,8 +127,6 @@ jobs:
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
merge-and-scan:
needs: [build-amd64, build-arm64, check_model_server_changes]
@@ -168,8 +156,6 @@ jobs:
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -8,25 +8,9 @@ on:
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
DEPLOYMENT: standalone
jobs:
precheck:
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}"]
outputs:
should-run: ${{ steps.set-output.outputs.should-run }}
steps:
- name: Check if tag contains "cloud"
id: set-output
run: |
if [[ "${{ github.ref_name }}" == *cloud* ]]; then
echo "should-run=false" >> "$GITHUB_OUTPUT"
else
echo "should-run=true" >> "$GITHUB_OUTPUT"
fi
build:
needs: precheck
if: needs.precheck.outputs.should-run == 'true'
runs-on:
- runs-on
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
@@ -53,11 +37,9 @@ jobs:
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -80,13 +62,11 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/web-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/${{ env.DEPLOYMENT }}/web-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# no-cache needed due to weird interactions with the builds for different platforms
# NOTE(rkuo): this may not be true any more with the proper cache prefixing by architecture - currently testing with it off
- name: Export digest
run: |
mkdir -p /tmp/digests
@@ -96,22 +76,21 @@ jobs:
- name: Upload digest
uses: actions/upload-artifact@v4
with:
name: web-digests-${{ env.PLATFORM_PAIR }}-${{ github.run_id }}
name: digests-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1
merge:
runs-on: ubuntu-latest
needs:
- build
if: needs.precheck.outputs.should-run == 'true'
runs-on: ubuntu-latest
steps:
- name: Download digests
uses: actions/download-artifact@v4
with:
path: /tmp/digests
pattern: web-digests-*-${{ github.run_id }}
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
@@ -122,11 +101,6 @@ jobs:
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3
@@ -154,8 +128,6 @@ jobs:
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -37,11 +37,6 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# uncomment to force run chart-testing
# - name: Force run chart-testing (list-changed)
# id: list-changed
# run: echo "changed=true" >> $GITHUB_OUTPUT
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -16,7 +16,6 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests:
@@ -62,8 +61,8 @@ jobs:
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
@@ -74,8 +73,8 @@ jobs:
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
@@ -86,8 +85,8 @@ jobs:
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
@@ -159,7 +158,6 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker

View File

@@ -16,7 +16,7 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
@@ -61,8 +61,8 @@ jobs:
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
@@ -73,8 +73,8 @@ jobs:
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
@@ -85,8 +85,8 @@ jobs:
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers

View File

@@ -10,7 +10,6 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true
jobs:
playwright-tests:

View File

@@ -12,7 +12,7 @@ env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -20,12 +20,10 @@ env:
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Gong
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}
@@ -35,52 +33,37 @@ env:
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Zendesk
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitlab
GITLAB_ACCESS_TOKEN: ${{ secrets.GITLAB_ACCESS_TOKEN }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
# Highspot
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
# Slack
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
@@ -112,15 +95,7 @@ jobs:
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'

View File

@@ -1,4 +1,4 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
# Contributing to Onyx

View File

@@ -1,4 +1,4 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<a name="readme-top"></a>
@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">

backend/.gitignore (vendored), 4 lines changed
View File

@@ -9,6 +9,4 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/
.test.env
onyx/connectors/salesforce/data/

View File

@@ -1,4 +1,4 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/backend/alembic/README.md"} -->
# Alembic DB Migrations

View File

@@ -24,7 +24,6 @@ from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import SqlEngine
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
@@ -148,9 +147,6 @@ async def run_async_migrations() -> None:
continue_on_error,
) = get_schema_options()
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
engine = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -184,10 +180,10 @@ async def run_async_migrations() -> None:
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue=true is not set, raising exception!")
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue=true is set, continuing to next schema.")
logger.warning("--continue is set, continuing to next schema.")
else:
try:
@@ -206,21 +202,10 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""
NOTE(rkuo): This generates a sql script that can be used to migrate the database ...
instead of migrating the db live via an open connection
Not clear on when this would be used by us or if it even works.
If it is offline, then why are there calls to the db engine?
This doesn't really get used when we migrate in the cloud."""
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
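
For context on the hunk above: the removed lines wired `SqlEngine.init_engine` and a `--continue` flag into the per-schema migration loop. A minimal, self-contained sketch of that continue-on-error loop follows, using stand-in helper names rather than the upstream env.py code:

# Minimal sketch of a per-schema migration loop honoring a continue-on-error flag.
# Stand-in names; the real code lives in alembic/env.py and uses get_schema_options().
import logging
from typing import Callable

logger = logging.getLogger(__name__)

def migrate_all_schemas(
    schemas: list[str],
    migrate_one: Callable[[str], None],
    continue_on_error: bool,
) -> None:
    for schema in schemas:
        try:
            migrate_one(schema)  # e.g. run the alembic upgrade for this tenant schema
        except Exception as e:
            logger.error(f"Error migrating schema {schema}: {e}")
            if not continue_on_error:
                raise  # fail fast when --continue is not set
            logger.warning("--continue is set, continuing to next schema.")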

View File

@@ -1,150 +0,0 @@
"""Fix invalid model-configurations state
Revision ID: 47a07e1a38f1
Revises: 7a70b7664e37
Create Date: 2025-04-23 15:39:43.159504
"""
from alembic import op
from pydantic import BaseModel, ConfigDict
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.llm.llm_provider_options import (
fetch_model_names_for_provider_as_set,
fetch_visible_model_names_for_provider_as_set,
)
# revision identifiers, used by Alembic.
revision = "47a07e1a38f1"
down_revision = "7a70b7664e37"
branch_labels = None
depends_on = None
class _SimpleModelConfiguration(BaseModel):
# Configure model to read from attributes
model_config = ConfigDict(from_attributes=True)
id: int
llm_provider_id: int
name: str
is_visible: bool
max_input_tokens: int | None
def upgrade() -> None:
llm_provider_table = sa.sql.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
sa.column("default_model_name", sa.String),
sa.column("fast_default_model_name", sa.String),
)
model_configuration_table = sa.sql.table(
"model_configuration",
sa.column("id", sa.Integer),
sa.column("llm_provider_id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_visible", sa.Boolean),
sa.column("max_input_tokens", sa.Integer),
)
connection = op.get_bind()
llm_providers = connection.execute(
sa.select(
llm_provider_table.c.id,
llm_provider_table.c.provider,
)
).fetchall()
for llm_provider in llm_providers:
llm_provider_id, provider_name = llm_provider
default_models = fetch_model_names_for_provider_as_set(provider_name)
display_models = fetch_visible_model_names_for_provider_as_set(
provider_name=provider_name
)
# if `fetch_model_names_for_provider_as_set` returns `None`, then
# that means that `provider_name` is not a well-known llm provider.
if not default_models:
continue
if not display_models:
raise RuntimeError(
"If `default_models` is non-None, `display_models` must be non-None too."
)
model_configurations = [
_SimpleModelConfiguration.model_validate(model_configuration)
for model_configuration in connection.execute(
sa.select(
model_configuration_table.c.id,
model_configuration_table.c.llm_provider_id,
model_configuration_table.c.name,
model_configuration_table.c.is_visible,
model_configuration_table.c.max_input_tokens,
).where(model_configuration_table.c.llm_provider_id == llm_provider_id)
).fetchall()
]
if model_configurations:
at_least_one_is_visible = any(
[
model_configuration.is_visible
for model_configuration in model_configurations
]
)
# If there is at least one model which is public, this is a valid state.
# Therefore, don't touch it and move on to the next one.
if at_least_one_is_visible:
continue
existing_visible_model_names: set[str] = set(
[
model_configuration.name
for model_configuration in model_configurations
if model_configuration.is_visible
]
)
difference = display_models.difference(existing_visible_model_names)
for model_name in difference:
if not model_name:
continue
insert_statement = postgresql.insert(model_configuration_table).values(
llm_provider_id=llm_provider_id,
name=model_name,
is_visible=True,
max_input_tokens=None,
)
connection.execute(
insert_statement.on_conflict_do_update(
index_elements=["llm_provider_id", "name"],
set_={"is_visible": insert_statement.excluded.is_visible},
)
)
else:
for model_name in default_models:
connection.execute(
model_configuration_table.insert().values(
llm_provider_id=llm_provider_id,
name=model_name,
is_visible=model_name in display_models,
max_input_tokens=None,
)
)
def downgrade() -> None:
pass
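
The removed migration above backfills visibility rows with SQLAlchemy's PostgreSQL `ON CONFLICT DO UPDATE` support. A standalone sketch of that upsert pattern, with a hypothetical, trimmed-down table definition rather than the project's real schema:

# Standalone sketch of the ON CONFLICT DO UPDATE (upsert) pattern used in the migration above.
# The table definition here is hypothetical and limited to the relevant columns.
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert

metadata = sa.MetaData()
model_configuration = sa.Table(
    "model_configuration",
    metadata,
    sa.Column("llm_provider_id", sa.Integer, primary_key=True),
    sa.Column("name", sa.String, primary_key=True),
    sa.Column("is_visible", sa.Boolean, nullable=False),
)

def mark_model_visible(conn, provider_id: int, model_name: str) -> None:
    stmt = insert(model_configuration).values(
        llm_provider_id=provider_id, name=model_name, is_visible=True
    )
    # On a (llm_provider_id, name) conflict, update visibility instead of failing the insert.
    conn.execute(
        stmt.on_conflict_do_update(
            index_elements=["llm_provider_id", "name"],
            set_={"is_visible": stmt.excluded.is_visible},
        )
    )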

View File

@@ -1,24 +0,0 @@
"""Add content type to UserFile
Revision ID: 5c448911b12f
Revises: 47a07e1a38f1
Create Date: 2025-04-25 16:59:48.182672
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5c448911b12f"
down_revision = "47a07e1a38f1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("user_file", sa.Column("content_type", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("user_file", "content_type")

View File

@@ -6,6 +6,12 @@ Create Date: 2025-04-01 07:26:10.539362
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import datetime
# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
@@ -13,10 +19,99 @@ branch_labels = None
depends_on = None
# Leaving this around only because some people might be on this migration
# originally was a duplicate of the user files migration
def upgrade() -> None:
pass
# Check if user_file table already exists
conn = op.get_bind()
inspector = inspect(conn)
if not inspector.has_table("user_file"):
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:

View File

@@ -1,237 +0,0 @@
"""Add model-configuration table
Revision ID: 7a70b7664e37
Revises: d961aca62eb3
Create Date: 2025-04-10 15:00:35.984669
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.llm.llm_provider_options import (
fetch_model_names_for_provider_as_set,
fetch_visible_model_names_for_provider_as_set,
)
# revision identifiers, used by Alembic.
revision = "7a70b7664e37"
down_revision = "d961aca62eb3"
branch_labels = None
depends_on = None
def _resolve(
provider_name: str,
model_names: list[str] | None,
display_model_names: list[str] | None,
default_model_name: str,
fast_default_model_name: str | None,
) -> set[tuple[str, bool]]:
models = set(model_names) if model_names else None
display_models = set(display_model_names) if display_model_names else None
# If both are defined, we need to make sure that `model_names` is a superset of `display_model_names`.
if models and display_models:
models = display_models.union(models)
# If only `model_names` is defined, then:
# - If default-model-names are available for the `provider_name`, then set `display_model_names` to it
# and set `model_names` to the union of those default-model-names with itself.
# - If no default-model-names are available, then set `display_models` to `models`.
#
# This preserves the invariant that `display_models` is a subset of `models`.
elif models and not display_models:
visible_default_models = fetch_visible_model_names_for_provider_as_set(
provider_name=provider_name
)
if visible_default_models:
display_models = set(visible_default_models)
models = display_models.union(models)
else:
display_models = set(models)
# If only the `display_model_names` are defined, then set `models` to the union of `display_model_names`
# and the default-model-names for that provider.
#
# This will also preserve the invariant that `display_models` is a subset of `models`.
elif not models and display_models:
default_models = fetch_model_names_for_provider_as_set(
provider_name=provider_name
)
if default_models:
models = display_models.union(default_models)
else:
models = set(display_models)
# If neither are defined, then set `models` and `display_models` to the default-model-names for the given provider.
#
# This will also preserve the invariant that `display_models` is a subset of `models`.
else:
default_models = fetch_model_names_for_provider_as_set(
provider_name=provider_name
)
visible_default_models = fetch_visible_model_names_for_provider_as_set(
provider_name=provider_name
)
if default_models:
if not visible_default_models:
raise RuntimeError
raise RuntimeError(
"If `default_models` is non-None, `visible_default_models` must be non-None too."
)
models = default_models
display_models = visible_default_models
# This is not a well-known llm-provider; we can't provide any model suggestions.
# Therefore, we set to the empty set and continue
else:
models = set()
display_models = set()
# It is possible that `default_model_name` is not in `models` and is not in `display_models`.
# It is also possible that `fast_default_model_name` is not in `models` and is not in `display_models`.
models.add(default_model_name)
if fast_default_model_name:
models.add(fast_default_model_name)
display_models.add(default_model_name)
if fast_default_model_name:
display_models.add(fast_default_model_name)
return set([(model, model in display_models) for model in models])
def upgrade() -> None:
op.create_table(
"model_configuration",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("is_visible", sa.Boolean(), nullable=False),
sa.Column("max_input_tokens", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["llm_provider_id"], ["llm_provider.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("llm_provider_id", "name"),
)
# Create temporary sqlalchemy references to tables for data migration
llm_provider_table = sa.sql.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.Integer),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
sa.column("default_model_name", sa.String),
sa.column("fast_default_model_name", sa.String),
)
model_configuration_table = sa.sql.table(
"model_configuration",
sa.column("id", sa.Integer),
sa.column("llm_provider_id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_visible", sa.Boolean),
sa.column("max_input_tokens", sa.Integer),
)
connection = op.get_bind()
llm_providers = connection.execute(
sa.select(
llm_provider_table.c.id,
llm_provider_table.c.provider,
llm_provider_table.c.model_names,
llm_provider_table.c.display_model_names,
llm_provider_table.c.default_model_name,
llm_provider_table.c.fast_default_model_name,
)
).fetchall()
for llm_provider in llm_providers:
provider_id = llm_provider[0]
provider_name = llm_provider[1]
model_names = llm_provider[2]
display_model_names = llm_provider[3]
default_model_name = llm_provider[4]
fast_default_model_name = llm_provider[5]
model_configurations = _resolve(
provider_name=provider_name,
model_names=model_names,
display_model_names=display_model_names,
default_model_name=default_model_name,
fast_default_model_name=fast_default_model_name,
)
for model_name, is_visible in model_configurations:
connection.execute(
model_configuration_table.insert().values(
llm_provider_id=provider_id,
name=model_name,
is_visible=is_visible,
max_input_tokens=None,
)
)
op.drop_column("llm_provider", "model_names")
op.drop_column("llm_provider", "display_model_names")
def downgrade() -> None:
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
model_configuration = sa.table(
"model_configuration",
sa.column("id", sa.Integer),
sa.column("llm_provider_id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_visible", sa.Boolean),
sa.column("max_input_tokens", sa.Integer),
)
op.add_column(
"llm_provider",
sa.Column(
"model_names",
postgresql.ARRAY(sa.VARCHAR()),
autoincrement=False,
nullable=True,
),
)
op.add_column(
"llm_provider",
sa.Column(
"display_model_names",
postgresql.ARRAY(sa.VARCHAR()),
autoincrement=False,
nullable=True,
),
)
connection = op.get_bind()
provider_ids = connection.execute(sa.select(llm_provider.c.id)).fetchall()
for (provider_id,) in provider_ids:
# Get all models for this provider
models = connection.execute(
sa.select(
model_configuration.c.name, model_configuration.c.is_visible
).where(model_configuration.c.llm_provider_id == provider_id)
).fetchall()
all_models = [model[0] for model in models]
visible_models = [model[0] for model in models if model[1]]
# Update provider with arrays
op.execute(
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(model_names=all_models, display_model_names=visible_models)
)
op.drop_table("model_configuration")
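
The `_resolve` helper in the removed migration above exists to keep visible models a subset of all configured models while forcing default models to be visible. A small worked example of the intended output, using hypothetical model names:

# Worked example of what _resolve is meant to produce (hypothetical inputs).
model_names = ["gpt-4o", "gpt-4o-mini"]   # configured models
display_model_names = ["gpt-4o"]          # only this one is meant to be visible
default_model_name = "o3-mini"            # defaults get added and forced visible

expected = {
    ("gpt-4o", True),
    ("gpt-4o-mini", False),
    ("o3-mini", True),
}
# Invariant: every visible model name is also present in the full model set.
assert {n for n, visible in expected if visible} <= {n for n, _ in expected}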

View File

@@ -103,7 +103,6 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_column("connector_credential_pair", "is_user_file")
# Drop the persona__user_folder table
op.drop_table("persona__user_folder")
# Drop the persona__user_file table
@@ -112,3 +111,4 @@ def downgrade() -> None:
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")
op.drop_column("connector_credential_pair", "is_user_file")

View File

@@ -1,32 +0,0 @@
"""Add public_external_user_group table
Revision ID: a7688ab35c45
Revises: 5c448911b12f
Create Date: 2025-05-06 20:55:12.747875
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a7688ab35c45"
down_revision = "5c448911b12f"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"public_external_user_group",
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("external_user_group_id", "cc_pair_id"),
sa.ForeignKeyConstraint(
["cc_pair_id"], ["connector_credential_pair.id"], ondelete="CASCADE"
),
)
def downgrade() -> None:
op.drop_table("public_external_user_group")

View File

@@ -1,57 +0,0 @@
"""Update status length
Revision ID: d961aca62eb3
Revises: cf90764725d8
Create Date: 2025-03-23 16:10:05.683965
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d961aca62eb3"
down_revision = "cf90764725d8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing enum type constraint
op.execute("ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE varchar")
# Create new enum type with all values
op.execute(
"ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE VARCHAR(20) USING status::varchar(20)"
)
# Update the enum type to include all possible values
op.alter_column(
"connector_credential_pair",
"status",
type_=sa.Enum(
"SCHEDULED",
"INITIAL_INDEXING",
"ACTIVE",
"PAUSED",
"DELETING",
"INVALID",
name="connectorcredentialpairstatus",
native_enum=False,
),
existing_type=sa.String(20),
nullable=False,
)
op.add_column(
"connector_credential_pair",
sa.Column(
"in_repeated_error_state", sa.Boolean, default=False, server_default="false"
),
)
def downgrade() -> None:
# no need to convert back to the old enum type, since we're not using it anymore
op.drop_column("connector_credential_pair", "in_repeated_error_state")

View File

@@ -21,9 +21,6 @@ branch_labels = None
depends_on = None
PRESERVED_CONFIG_KEYS = ["comment_email_blacklist", "batch_size", "labels_to_skip"]
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
@@ -65,9 +62,6 @@ def upgrade() -> None:
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
for old_key in PRESERVED_CONFIG_KEYS:
if old_key in old_config:
new_config[old_key] = old_config[old_key]
# Update the connector config
conn.execute(
@@ -114,10 +108,6 @@ def downgrade() -> None:
else:
continue
for old_key in PRESERVED_CONFIG_KEYS:
if old_key in new_config:
old_config[old_key] = new_config[old_key]
# Update the connector config
conn.execute(
sa.text(
@@ -127,5 +117,5 @@ def downgrade() -> None:
WHERE id = :id
"""
),
{"id": connector_id, "old_config": json.dumps(old_config)},
{"id": connector_id, "old_config": old_config},
)

View File

@@ -1,7 +1,6 @@
from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.external_perm import fetch_public_external_group_ids
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
@@ -64,8 +63,6 @@ def _get_access_for_documents(
document_ids=document_ids,
)
all_public_ext_u_group_ids = set(fetch_public_external_group_ids(db_session))
access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
@@ -92,10 +89,7 @@ def _get_access_for_documents(
# If its censored, then it's public anywhere during the search and then permissions are
# applied after the search
is_public_anywhere = (
document.is_public
or non_ee_access.is_public
or is_only_censored
or any(u_group in all_public_ext_u_group_ids for u_group in ext_u_groups)
document.is_public or non_ee_access.is_public or is_only_censored
)
# To avoid collisions of group namings between connectors, they need to be prefixed
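
For readers skimming the hunk above: the removed branch treats a document as "public anywhere" when any of its external user groups appears in the public-group set. A tiny illustrative sketch with made-up group names:

# Illustrative only; group names are made up, not real connector data.
all_public_ext_u_group_ids = {"confluence_anonymous-access", "jira_public"}
ext_u_groups = ["confluence_anonymous-access", "engineering"]

is_public_anywhere = (
    False  # document.is_public
    or False  # non_ee_access.is_public
    or False  # is_only_censored
    or any(group in all_public_ext_u_group_ids for group in ext_u_groups)
)
assert is_public_anywhere  # public because one external group is marked public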

View File

@@ -1,140 +0,0 @@
import csv
import io
from datetime import datetime
from datetime import timezone
from celery import shared_task
from celery import Task
from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def export_query_history_task(self: Task, *, start: datetime, end: datetime) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
task_id = self.request.id
start_time = datetime.now(tz=timezone.utc)
with get_session_with_current_tenant() as db_session:
try:
register_task(
db_session=db_session,
task_name=query_history_task_name(start=start, end=end),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
complete_chat_session_history: list[ChatSessionSnapshot] = (
fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
feedback_type=None,
limit=None,
)
)
except Exception:
logger.exception(f"Failed to export query history with {task_id=}")
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
complete_chat_session_history = [
ChatSessionSnapshot(
**chat_session_snapshot.model_dump(), user_email=ONYX_ANONYMIZED_EMAIL
)
for chat_session_snapshot in complete_chat_session_history
]
qa_pairs: list[QuestionAnswerPairSnapshot] = [
qa_pair
for chat_session_snapshot in complete_chat_session_history
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
chat_session_snapshot
)
]
stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()
for row in qa_pairs:
writer.writerow(row.to_json())
report_name = construct_query_history_report_name(task_id)
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
file_name=report_name,
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
file_metadata={
"start": start.isoformat(),
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
)
delete_task_with_id(
db_session=db_session,
task_id=task_id,
)
except Exception:
logger.exception(
f"Failed to save query history export file; {report_name=}"
)
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
]
)

View File

@@ -1,8 +0,0 @@
from onyx.background.celery.apps.light import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -1,7 +0,0 @@
from onyx.background.celery.apps.monitoring import celery_app
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -1,22 +1,12 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -25,42 +15,18 @@ logger = setup_logger()
# mark as EE for all tasks in this file
@shared_task(
name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
@@ -68,26 +34,11 @@ def perform_ttl_management_task(
include_deleted=True,
hard_delete=True,
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
#####
@@ -96,7 +47,7 @@ def perform_ttl_management_task(
@celery_app.task(
name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
@@ -116,7 +67,7 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
@celery_app.task(
name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
@@ -128,12 +79,3 @@ def autogenerate_usage_report_task(*, tenant_id: str) -> None:
user_id=None,
period=None,
)
celery_app.autodiscover_tasks(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cloud",
]
)
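
The refactor above swaps the `build_celery_task_wrapper` decorator for explicit task-row bookkeeping via `register_task` and `mark_task_as_finished_with_id`. A condensed sketch of that register/finish shape, with stand-in callables rather than the real onyx.db.tasks helpers:

# Condensed sketch of the explicit bookkeeping pattern; register/finish are stand-ins,
# not the real onyx.db.tasks functions.
from datetime import datetime, timezone
from typing import Callable

def run_with_bookkeeping(
    task_id: str,
    task_name: str,
    register: Callable[..., None],
    finish: Callable[..., None],
    body: Callable[[], None],
) -> None:
    register(task_name=task_name, task_id=task_id,
             start_time=datetime.now(tz=timezone.utc))  # record a STARTED row up front
    try:
        body()  # the actual work, e.g. deleting expired chat sessions
    except Exception:
        finish(task_id=task_id, success=False)  # record failure before re-raising
        raise
    finish(task_id=task_id, success=True)  # record success once the body completes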

View File

@@ -1,7 +1,6 @@
from datetime import timedelta
from typing import Any
from ee.onyx.configs.app_configs import CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS
from onyx.background.celery.tasks.beat_schedule import (
beat_cloud_tasks as base_beat_system_tasks,
)
@@ -14,7 +13,6 @@ from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
@@ -35,20 +33,10 @@ ee_beat_task_templates.extend(
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
},
},
]
@@ -70,20 +58,10 @@ if not MULTI_TENANT:
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
},
},
]

View File

@@ -1,40 +0,0 @@
from datetime import datetime
from datetime import timedelta
from celery import shared_task
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def export_query_history_cleanup_task(*, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
tasks = get_all_query_history_export_tasks(db_session=db_session)
for task in tasks:
if task.status == TaskStatus.SUCCESS:
delete_task_with_id(db_session=db_session, task_id=task.task_id)
elif task.status == TaskStatus.FAILURE:
if task.start_time:
deadline = task.start_time + timedelta(hours=24)
now = datetime.now()
if now < deadline:
continue
logger.error(
f"Task with {task.task_id=} failed; it is being deleted now"
)
delete_task_with_id(db_session=db_session, task_id=task.task_id)
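
The cleanup above deletes successful export tasks immediately and keeps failed ones for 24 hours after their start time. A minimal standalone sketch of that retention rule (hypothetical helper; naive datetimes, as in the task above):

from datetime import datetime
from datetime import timedelta

def should_delete_failed_export(start_time: datetime | None, now: datetime | None = None) -> bool:
    """Failed export tasks are kept for 24 hours after they started, then deleted."""
    if start_time is None:
        return True  # nothing to wait for if no start time was recorded
    now = now or datetime.now()
    return now >= start_time + timedelta(hours=24)

print(should_delete_failed_export(datetime.now() - timedelta(hours=25)))  # True
print(should_delete_failed_export(datetime.now() - timedelta(hours=1)))   # False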

View File

@@ -1,104 +0,0 @@
import time
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
# needed in the cloud
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
self.app.send_task(
task_name,
kwargs=dict(
tenant_id=tenant_id,
),
queue=queue,
priority=priority,
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
finally:
if not lock_beat.owned():
task_logger.error(
"cloud_beat_task_generator - Lock not owned on completion"
)
redis_lock_dump(lock_beat, redis_client)
else:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
return True
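
The tenant loop above extends its Redis lock whenever a quarter of the lock timeout has elapsed, so a long fan-out never outlives the lock it holds. A minimal sketch of that heartbeat pattern in isolation (hypothetical helper, not part of this diff; reacquire() is the redis-py Lock method used above):

import time
from collections.abc import Callable
from collections.abc import Iterable
from redis.lock import Lock as RedisLock

def process_with_lock_heartbeat(
    lock: RedisLock,
    items: Iterable[str],
    handle: Callable[[str], None],
    lock_timeout: int,
) -> int:
    """Run handle() over items, re-extending the lock before it can expire."""
    last_touch = time.monotonic()
    processed = 0
    for item in items:
        now = time.monotonic()
        if now - last_touch >= lock_timeout / 4:
            lock.reacquire()  # resets the lock's TTL back to lock_timeout
            last_touch = now
        handle(item)
        processed += 1
    return processed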

View File

@@ -1,30 +0,0 @@
from sqlalchemy.orm import Session
from ee.onyx.external_permissions.sync_params import (
source_group_sync_is_cc_pair_agnostic,
)
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_source
from onyx.db.models import ConnectorCredentialPair
def _get_all_cc_pair_ids_to_mark_as_group_synced(
db_session: Session, cc_pair: ConnectorCredentialPair
) -> list[int]:
if not source_group_sync_is_cc_pair_agnostic(cc_pair.connector.source):
return [cc_pair.id]
cc_pairs = get_connector_credential_pairs_for_source(
db_session, cc_pair.connector.source
)
return [cc_pair.id for cc_pair in cc_pairs]
def mark_all_relevant_cc_pairs_as_external_group_synced(
db_session: Session, cc_pair: ConnectorCredentialPair
) -> None:
"""For some source types, one successful group sync run should count for all
cc pairs of that type. This function handles that case."""
cc_pair_ids = _get_all_cc_pair_ids_to_mark_as_group_synced(db_session, cc_pair)
for cc_pair_id in cc_pair_ids:
mark_cc_pair_as_external_group_synced(db_session, cc_pair_id)

View File

@@ -9,7 +9,7 @@ logger = setup_logger()
def should_perform_chat_ttl_check(
retention_limit_days: float | None, db_session: Session
retention_limit_days: int | None, db_session: Session
) -> bool:
# TODO: make this a check for None and add behavior for 0 day TTL
if not retention_limit_days:

View File

@@ -1,16 +1,2 @@
from datetime import datetime
from onyx.configs.constants import OnyxCeleryTask
QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK
def name_chat_ttl_task(
retention_limit_days: float, tenant_id: str | None = None
) -> str:
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def query_history_task_name(start: datetime, end: datetime) -> str:
return f"{QUERY_HISTORY_TASK_NAME_PREFIX}_{start}_{end}"

View File

@@ -25,25 +25,13 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
# should generally only be used for sources that support polling of permissions
# e.g. can pull in only permission changes rather than having to go through all
# documents every time
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
#####
# Confluence
#####
# In seconds, default is 30 minutes
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 30 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# This is a boolean that determines if anonymous access is public
# Default behavior is to not make the page public and instead add a group
@@ -51,34 +39,14 @@ CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
)
#####
# Google Drive
#####
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
#####
# Slack
#####
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
####
# Celery Job Frequency
####
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
os.environ.get("CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS") or 1
) # float for easier testing
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
@@ -94,6 +62,29 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
@@ -101,4 +92,6 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"
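
Most numeric settings in this file use the int(os.environ.get(NAME) or default) idiom rather than os.environ.get(NAME, default): the `or` form also falls back when the variable is set but empty. A standalone illustration (hypothetical variable name, not part of this diff):

import os

os.environ["EXAMPLE_SYNC_FREQUENCY"] = ""  # set but empty, e.g. from a blank .env entry

frequency = int(os.environ.get("EXAMPLE_SYNC_FREQUENCY") or 5 * 60)  # -> 300
# int(os.environ.get("EXAMPLE_SYNC_FREQUENCY", 5 * 60)) would raise ValueError on the empty string
print(frequency)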

View File

@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
(
or_(
ChatMessageFeedback.is_positive.is_(False),
ChatMessageFeedback.required_followup.is_(True),
ChatMessageFeedback.required_followup,
),
1,
),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
.all()
)
return [tuple(row) for row in results]
return results
def fetch_persona_message_analytics(

View File

@@ -8,7 +8,6 @@ from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
@@ -21,12 +20,6 @@ logger = setup_logger()
class ExternalUserGroup(BaseModel):
id: str
user_emails: list[str]
# `True` for cases like a Folder in Google Drive that give domain-wide
# or "Anyone with link" access to all files in the folder.
# if this is set, `user_emails` don't really matter.
# When this is `True`, this `ExternalUserGroup` object doesn't really represent
# an actual "group" in the source.
gives_anyone_access: bool = False
def delete_user__ext_group_for_user__no_commit(
@@ -51,17 +44,6 @@ def delete_user__ext_group_for_cc_pair__no_commit(
)
def delete_public_external_group_for_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
delete(PublicExternalUserGroup).where(
PublicExternalUserGroup.cc_pair_id == cc_pair_id
)
)
def replace_user__ext_group_for_cc_pair(
db_session: Session,
cc_pair_id: int,
@@ -90,22 +72,13 @@ def replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
delete_public_external_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# map emails to ids
email_id_map = {user.email: user.id for user in all_group_members}
# use these ids to create new external user group relations relating group_id to user_ids
new_external_permissions: list[User__ExternalUserGroupId] = []
new_public_external_groups: list[PublicExternalUserGroup] = []
new_external_permissions = []
for external_group in group_defs:
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email.lower())
if user_id is None:
@@ -114,6 +87,10 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
@@ -122,16 +99,7 @@ def replace_user__ext_group_for_cc_pair(
)
)
if external_group.gives_anyone_access:
new_public_external_groups.append(
PublicExternalUserGroup(
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
)
)
db_session.add_all(new_external_permissions)
db_session.add_all(new_public_external_groups)
db_session.commit()
@@ -162,11 +130,3 @@ def fetch_external_groups_for_user_email_and_group_ids(
)
).all()
return list(user_ext_groups)
def fetch_public_external_group_ids(
db_session: Session,
) -> list[str]:
return list(
db_session.scalars(select(PublicExternalUserGroup.external_user_group_id)).all()
)

View File

@@ -11,7 +11,6 @@ from onyx.server.features.persona.models import PersonaSharedNotificationData
def make_persona_private(
persona_id: int,
creator_user_id: UUID | None,
user_ids: list[UUID] | None,
group_ids: list[int] | None,
db_session: Session,
@@ -30,15 +29,15 @@ def make_persona_private(
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
if user_id != creator_user_id:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
if group_ids:
group_ids_set = set(group_ids)

View File

@@ -15,13 +15,10 @@ from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
from onyx.db.models import TaskQueueState
from onyx.db.tasks import get_all_tasks_with_prefix
def _build_filter_conditions(
@@ -174,9 +171,3 @@ def fetch_chat_sessions_eagerly_by_time(
chat_sessions = query.all()
return chat_sessions
def get_all_query_history_export_tasks(
db_session: Session,
) -> list[TaskQueueState]:
return get_all_tasks_with_prefix(db_session, QUERY_HISTORY_TASK_NAME_PREFIX)

View File

@@ -8,7 +8,6 @@ from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
@@ -161,24 +160,16 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"permissions for space '{space_key}'"
)
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key
def _extract_read_access_restrictions(
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str], bool]:
) -> tuple[set[str], set[str]]:
"""
Converts a page's restrictions dict into an ExternalAccess object.
If there are no restrictions, then return None
@@ -189,9 +180,6 @@ def _extract_read_access_restrictions(
# Extract the users with read access
read_access_user = read_access_restrictions.get("user", {})
read_access_user_jsons = read_access_user.get("results", [])
# any items found means that there is a restriction
found_any_restriction = bool(read_access_user_jsons)
read_access_user_emails = []
for user in read_access_user_jsons:
# If the user has an email, then add it to the list
@@ -220,17 +208,11 @@ def _extract_read_access_restrictions(
# Extract the groups with read access
read_access_group = read_access_restrictions.get("group", {})
read_access_group_jsons = read_access_group.get("results", [])
# any items found means that there is a restriction
found_any_restriction |= bool(read_access_group_jsons)
read_access_group_names = [
group["name"] for group in read_access_group_jsons if group.get("name")
]
return (
set(read_access_user_emails),
set(read_access_group_names),
found_any_restriction,
)
return set(read_access_user_emails), set(read_access_group_names)
def _get_all_page_restrictions(
@@ -238,64 +220,46 @@ def _get_all_page_restrictions(
perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
"""
This function gets the restrictions for a page. In Confluence, a child can have
at MOST the same level accessibility as its immediate parent.
This function gets the restrictions for a page by taking the intersection
of the page's restrictions and the restrictions of all the ancestors
of the page.
If the page/ancestor has no restrictions, then it is ignored (no intersection).
If no restrictions are found anywhere, then return None, indicating that the page
should inherit the space's restrictions.
"""
found_user_emails: set[str] = set()
found_group_names: set[str] = set()
# NOTE: need the found_any_restriction, since we can find restrictions
# but not be able to extract any user emails or group names
# in this case, we should just give no access
found_user_emails, found_group_names, found_any_page_level_restriction = (
_extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=perm_sync_data.get("restrictions", {}),
)
found_user_emails, found_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=perm_sync_data.get("restrictions", {}),
)
# if there are individual page-level restrictions, then this is the accurate
# restriction for the page. You cannot both have page-level restrictions AND
# inherit restrictions from the parent.
if found_any_page_level_restriction:
return ExternalAccess(
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
is_public=False,
)
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
# ancestors seem to be in order from root to immediate parent
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
# we want the restrictions from the immediate parent to take precedence, so we should
# reverse the list
for ancestor in reversed(ancestors):
(
ancestor_user_emails,
ancestor_group_names,
found_any_restrictions_in_ancestor,
) = _extract_read_access_restrictions(
for ancestor in ancestors:
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=ancestor.get("restrictions", {}),
)
if found_any_restrictions_in_ancestor:
# if inheriting restrictions from the parent, then the first one we run into
# should be applied (the reason why we'd traverse more than one ancestor is if
# the ancestor also is in "inherit" mode.)
logger.info(
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
f"for document {perm_sync_data.get('id')} based on ancestor {ancestor}"
)
return ExternalAccess(
external_user_emails=ancestor_user_emails,
external_user_group_ids=ancestor_group_names,
is_public=False,
)
if not ancestor_user_emails and not ancestor_group_names:
# This ancestor has no restrictions, so it has no effect on
# the page's restrictions, so we ignore it
continue
# we didn't find any restrictions, so the page inherits the space's restrictions
return None
found_user_emails.intersection_update(ancestor_user_emails)
found_group_names.intersection_update(ancestor_group_names)
# If there are no restrictions found, then the page
# inherits the space's restrictions so return None
if not found_user_emails and not found_group_names:
return None
return ExternalAccess(
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
# there is no way for a page to be individually public if the space isn't public
is_public=False,
)
def _fetch_all_page_restrictions(
@@ -325,7 +289,6 @@ def _fetch_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
logger.info(f"Found restrictions {restrictions} for document {slim_doc.id}")
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
@@ -335,9 +298,8 @@ def _fetch_all_page_restrictions(
space_key = slim_doc.perm_sync_data.get("space_key")
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
logger.warning(
f"Individually fetching space permissions for space {space_key}. This is "
"unexpected. It means the permissions were not able to fetched initially."
logger.debug(
f"Individually fetching space permissions for space {space_key}"
)
try:
# If the space permissions are not in the cache, then fetch them
@@ -360,15 +322,6 @@ def _fetch_all_page_restrictions(
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
# be safe, if we can't get the permissions then make the document inaccessible
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
),
)
continue
# If there are no restrictions, then use the space's restrictions
@@ -383,24 +336,24 @@ def _fetch_all_page_restrictions(
):
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions may be wrong for"
"This means space permissions are may be wrong for"
f" Space key: {space_key}"
)
logger.info("Finished fetching all page restrictions")
logger.debug("Finished fetching all page restrictions for space")
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""
Fetches document permissions from Confluence and yields DocExternalAccess objects.
Compares fetched documents against existing documents in the DB for the connector.
If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
it in postgres so that when it gets created later, the permissions are
already populated
"""
logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
logger.debug("Starting confluence doc sync")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
@@ -416,16 +369,13 @@ def confluence_doc_sync(
confluence_client=confluence_connector.confluence_client,
is_cloud=is_cloud,
)
logger.info("Space permissions by space key:")
for space_key, space_permissions in space_permissions_by_space_key.items():
logger.info(f"Space key: {space_key}, Permissions: {space_permissions}")
slim_docs: list[SlimDocument] = []
logger.info("Fetching all slim documents from confluence")
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.info(f"Got {len(doc_batch)} slim documents from confluence")
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
@@ -434,32 +384,7 @@ def confluence_doc_sync(
slim_docs.extend(doc_batch)
# Find documents that are no longer accessible in Confluence
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
existing_doc_ids = fetch_all_existing_docs_fn()
# Find missing doc IDs
fetched_doc_ids = {doc.id for doc in slim_docs}
missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids
# Yield access removal for missing docs. Better to be safe.
if missing_doc_ids:
logger.warning(
f"Found {len(missing_doc_ids)} documents that are in the DB but "
"not present in Confluence fetch. Making them inaccessible."
)
for missing_id in missing_doc_ids:
logger.warning(f"Removing access for document ID: {missing_id}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
),
)
logger.info("Fetching all page restrictions for fetched documents")
logger.debug("Fetching all page restrictions for space")
yield from _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
@@ -467,5 +392,3 @@ def confluence_doc_sync(
is_cloud=is_cloud,
callback=callback,
)
logger.info("Finished confluence doc sync")

View File

@@ -2,7 +2,6 @@ from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
@@ -35,7 +34,6 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -3,15 +3,10 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
@@ -21,6 +16,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
@@ -43,28 +40,46 @@ def _get_slim_doc_generator(
def _fetch_permissions_for_permission_ids(
google_drive_connector: GoogleDriveConnector,
permission_ids: list[str],
permission_info: dict[str, Any],
) -> list[GoogleDrivePermission]:
) -> list[dict[str, Any]]:
doc_id = permission_info.get("doc_id")
if not permission_info or not doc_id:
return []
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
permission_ids = permission_info.get("permission_ids", [])
if not permission_ids:
return []
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
return get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
return permissions_for_doc_id
def _get_permissions_from_slim_doc(
google_drive_connector: GoogleDriveConnector,
@@ -72,13 +87,14 @@ def _get_permissions_from_slim_doc(
) -> ExternalAccess:
permission_info = slim_doc.perm_sync_data or {}
permissions_list: list[GoogleDrivePermission] = []
raw_permissions_list = permission_info.get("permissions", [])
if not raw_permissions_list:
permissions_list = _fetch_permissions_for_permission_ids(
google_drive_connector=google_drive_connector,
permission_info=permission_info,
)
permissions_list = permission_info.get("permissions", [])
if not permissions_list:
if permission_ids := permission_info.get("permission_ids"):
permissions_list = _fetch_permissions_for_permission_ids(
google_drive_connector=google_drive_connector,
permission_ids=permission_ids,
permission_info=permission_info,
)
if not permissions_list:
logger.warning(f"No permissions found for document {slim_doc.id}")
return ExternalAccess(
@@ -86,71 +102,41 @@ def _get_permissions_from_slim_doc(
external_user_group_ids=set(),
is_public=False,
)
else:
permissions_list = [
GoogleDrivePermission.from_drive_permission(p) for p in raw_permissions_list
]
company_domain = google_drive_connector.google_domain
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
# if the permission is inherited, do not add it directly to the file
# instead, add the folder ID as a group that has access to the file
# we will then handle mapping that folder to the list of Onyx users
# in the group sync job
# NOTE: this doesn't handle the case where a folder initially has no
# permissioning, but then later that folder is shared with a user or group.
# We could fetch all ancestors of the file to get the list of folders that
# might affect the permissions of the file, but this will get replaced with
# an audit-log based approach in the future so not doing it now.
if (
permission.permission_details
and permission.permission_details.inherited_from
):
folder_ids_to_inherit_permissions_from.add(
permission.permission_details.inherited_from
)
if not permission:
skipped_permissions += 1
continue
if permission.type == PermissionType.USER:
if permission.email_address:
user_emails.add(permission.email_address)
else:
logger.error(
"Permission is type `user` but no email address is "
f"provided for document {slim_doc.id}"
f"\n {permission}"
)
elif permission.type == PermissionType.GROUP:
# groups are represented as email addresses within Drive
if permission.email_address:
group_emails.add(permission.email_address)
else:
logger.error(
"Permission is type `group` but no email address is "
f"provided for document {slim_doc.id}"
f"\n {permission}"
)
elif permission.type == PermissionType.DOMAIN and company_domain:
if permission.domain == company_domain:
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
elif permission_type == "group":
group_emails.add(permission["emailAddress"])
elif permission_type == "domain" and company_domain:
if permission.get("domain") == company_domain:
public = True
else:
logger.warning(
"Permission is type domain but does not match company domain:"
f"\n {permission}"
)
elif permission.type == PermissionType.ANYONE:
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = (
group_emails
| folder_ids_to_inherit_permissions_from
| ({drive_id} if drive_id is not None else set())
)
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
@@ -161,7 +147,6 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -1,84 +0,0 @@
from collections.abc import Iterator
from googleapiclient.discovery import Resource # type: ignore
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
get_permissions_by_ids,
)
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.file_retrieval import generate_time_range_filter
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Only include fields we need - folder ID and permissions
# IMPORTANT: must fetch permissionIds, since sometimes the drive API
# seems to miss permissions when requesting them directly
FOLDER_PERMISSION_FIELDS = (
"nextPageToken, files(id, name, permissionIds, "
"permissions(id, emailAddress, type, domain, permissionDetails))"
)
def get_folder_permissions_by_ids(
service: Resource,
folder_id: str,
permission_ids: list[str],
) -> list[GoogleDrivePermission]:
"""
Retrieves permissions for a specific folder filtered by permission IDs.
Args:
service: The Google Drive service instance
folder_id: The ID of the folder to fetch permissions for
permission_ids: A list of permission IDs to filter by
Returns:
A list of permissions matching the provided permission IDs
"""
return get_permissions_by_ids(
drive_service=service,
doc_id=folder_id,
permission_ids=permission_ids,
)
def get_modified_folders(
service: Resource,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
"""
Retrieves all folders that were modified within the specified time range.
Only includes folder ID and permission information, not any contained files.
Args:
service: The Google Drive service instance
start: The start time as seconds since Unix epoch (inclusive)
end: The end time as seconds since Unix epoch (inclusive)
Returns:
An iterator yielding folder information including ID and permissions
"""
# Build query for folders
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += generate_time_range_filter(start, end)
# Retrieve and yield folders
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
supportsAllDrives=True,
includeItemsFromAllDrives=True,
includePermissionsForView="published",
fields=FOLDER_PERMISSION_FIELDS,
q=query,
):
yield folder

View File

@@ -1,15 +1,6 @@
from googleapiclient.errors import HttpError # type: ignore
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.google_drive.folder_retrieval import (
get_folder_permissions_by_ids,
)
from ee.onyx.external_permissions.google_drive.folder_retrieval import (
get_modified_folders,
)
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
@@ -21,77 +12,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
"""
Folder Permission Sync.
Each folder is treated as a group. Each file has all ancestor folders
as groups.
"""
class FolderInfo(BaseModel):
id: str
permissions: list[GoogleDrivePermission]
def _get_all_folders(
google_drive_connector: GoogleDriveConnector, skip_folders_without_permissions: bool
) -> list[FolderInfo]:
"""Have to get all folders since the group syncing system assumes all groups
are returned every time.
TODO: tweak things so we can fetch deltas.
"""
all_folders: list[FolderInfo] = []
seen_folder_ids: set[str] = set()
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
drive_service = get_drive_service(
google_drive_connector.creds,
user_email,
)
for folder in get_modified_folders(
service=drive_service,
):
folder_id = folder["id"]
if folder_id in seen_folder_ids:
logger.debug(f"Folder {folder_id} has already been seen. Skipping.")
continue
# Check if the folder has permission IDs but no permissions
permission_ids = folder.get("permissionIds", [])
raw_permissions = folder.get("permissions", [])
if not raw_permissions and permission_ids:
# Fetch permissions using the IDs
permissions = get_folder_permissions_by_ids(
drive_service, folder_id, permission_ids
)
else:
permissions = [
GoogleDrivePermission.from_drive_permission(permission)
for permission in raw_permissions
]
if not permissions and skip_folders_without_permissions:
continue
all_folders.append(
FolderInfo(
id=folder_id,
permissions=permissions,
)
)
seen_folder_ids.add(folder_id)
return all_folders
"""Individual Shared Drive / My Drive Permission Sync"""
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
admin_service: AdminService,
@@ -131,17 +51,15 @@ def _get_drive_members(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type),nextPageToken",
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
# can only set `useDomainAdminAccess` to true if the user
# is an admin
useDomainAdminAccess=is_admin,
):
# NOTE: don't need to check for PermissionType.ANYONE since
# you can't share a drive with the internet
if permission["type"] == PermissionType.GROUP:
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == PermissionType.USER:
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
except HttpError as e:
if e.status_code == 404:
@@ -169,7 +87,7 @@ def _get_all_groups(
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email),nextPageToken",
fields="groups(email)",
):
group_emails.add(group["email"])
return group_emails
@@ -189,7 +107,7 @@ def _map_group_email_to_member_emails(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email),nextPageToken",
fields="members(email)",
):
group_member_emails.add(member["email"])
@@ -200,7 +118,6 @@ def _map_group_email_to_member_emails(
def _build_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, set[str]],
folder_info: list[FolderInfo],
) -> list[ExternalUserGroup]:
onyx_groups: list[ExternalUserGroup] = []
@@ -208,52 +125,13 @@ def _build_onyx_groups(
# This is because having drive level access means you have
# irrevocable access to all the files in the drive.
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
drive_member_emails: set[str] = user_emails
all_member_emails: set[str] = user_emails
for group_email in group_emails:
if group_email not in group_email_to_member_emails_map:
logger.warning(
f"Group email {group_email} for drive {drive_id} not found in "
"group_email_to_member_emails_map"
)
continue
drive_member_emails.update(group_email_to_member_emails_map[group_email])
all_member_emails.update(group_email_to_member_emails_map[group_email])
onyx_groups.append(
ExternalUserGroup(
id=drive_id,
user_emails=list(drive_member_emails),
)
)
# Convert all folder permissions to onyx groups
for folder in folder_info:
anyone_can_access = False
folder_member_emails: set[str] = set()
for permission in folder.permissions:
if permission.type == PermissionType.USER:
if permission.email_address is None:
logger.warning(
f"User email is None for folder {folder.id} permission {permission}"
)
continue
folder_member_emails.add(permission.email_address)
elif permission.type == PermissionType.GROUP:
if permission.email_address not in group_email_to_member_emails_map:
logger.warning(
f"Group email {permission.email_address} for folder {folder.id} "
"not found in group_email_to_member_emails_map"
)
continue
folder_member_emails.update(
group_email_to_member_emails_map[permission.email_address]
)
elif permission.type == PermissionType.ANYONE:
anyone_can_access = True
onyx_groups.append(
ExternalUserGroup(
id=folder.id,
user_emails=list(folder_member_emails),
gives_anyone_access=anyone_can_access,
user_emails=list(all_member_emails),
)
)
@@ -290,12 +168,6 @@ def gdrive_group_sync(
admin_service, google_drive_connector.google_domain
)
# Get all folder permissions
folder_info = _get_all_folders(
google_drive_connector=google_drive_connector,
skip_folders_without_permissions=True,
)
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
@@ -305,7 +177,6 @@ def gdrive_group_sync(
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
folder_info=folder_info,
)
return onyx_groups
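
In the restored _build_onyx_groups, each drive becomes one ExternalUserGroup whose members are the drive's direct users plus the expanded membership of every group granted access to the drive. A simplified standalone sketch of that flattening step (plain dicts and sets, not part of this diff):

def flatten_drive_members(
    drive_id_to_members: dict[str, tuple[set[str], set[str]]],
    group_email_to_member_emails: dict[str, set[str]],
) -> dict[str, set[str]]:
    """Map drive id -> all member emails, expanding group emails into their members."""
    flattened: dict[str, set[str]] = {}
    for drive_id, (group_emails, user_emails) in drive_id_to_members.items():
        members = set(user_emails)
        for group_email in group_emails:
            # groups missing from the map are skipped (the real code logs a warning)
            members.update(group_email_to_member_emails.get(group_email, set()))
        flattened[drive_id] = members
    return flattened

print(flatten_drive_members(
    {"drive-1": ({"eng@example.com"}, {"alice@example.com"})},
    {"eng@example.com": {"bob@example.com", "carol@example.com"}},
))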

View File

@@ -1,59 +0,0 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
class PermissionType(str, Enum):
USER = "user"
GROUP = "group"
DOMAIN = "domain"
ANYONE = "anyone"
class GoogleDrivePermissionDetails(BaseModel):
# this is "file", "member", etc.
# different from the `type` field within `GoogleDrivePermission`
# Sometimes can be None, although not sure why...
permission_type: str | None
# this is "reader", "writer", "owner", etc.
role: str
# this is the id of the parent permission
inherited_from: str | None
class GoogleDrivePermission(BaseModel):
id: str
# groups are also represented as email addresses within Drive
# will be None for domain/global permissions
email_address: str | None
type: PermissionType
domain: str | None # only applies to domain permissions
permission_details: GoogleDrivePermissionDetails | None
@classmethod
def from_drive_permission(
cls, drive_permission: dict[str, Any]
) -> "GoogleDrivePermission":
# we seem to only get details for permissions that are inherited
# we can get multiple details if a permission is inherited from multiple
# parents
permission_details_list = drive_permission.get("permissionDetails", [])
permission_details: dict[str, Any] | None = (
permission_details_list[0] if permission_details_list else None
)
return cls(
id=drive_permission["id"],
email_address=drive_permission.get("emailAddress"),
type=PermissionType(drive_permission["type"]),
domain=drive_permission.get("domain"),
permission_details=(
GoogleDrivePermissionDetails(
permission_type=permission_details.get("type"),
role=permission_details.get("role", ""),
inherited_from=permission_details.get("inheritedFrom"),
)
if permission_details
else None
),
)

View File

@@ -1,60 +0,0 @@
from googleapiclient.discovery import Resource # type: ignore
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_permissions_by_ids(
drive_service: Resource,
doc_id: str,
permission_ids: list[str],
) -> list[GoogleDrivePermission]:
"""
Fetches permissions for a document based on a list of permission IDs.
Args:
drive_service: The Google Drive service instance
doc_id: The ID of the document to fetch permissions for
permission_ids: A list of permission IDs to filter by
Returns:
A list of GoogleDrivePermission objects matching the provided permission IDs
"""
if not permission_ids:
return []
# Create a set for faster lookup
permission_id_set = set(permission_ids)
# Fetch all permissions for the document
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain, permissionDetails),nextPageToken",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
# Filter permissions by ID and convert to GoogleDrivePermission objects
filtered_permissions = []
for permission in fetched_permissions:
permission_id = permission.get("id")
if permission_id in permission_id_set:
google_drive_permission = GoogleDrivePermission.from_drive_permission(
permission
)
filtered_permissions.append(google_drive_permission)
# Log if we couldn't find all requested permission IDs
if len(filtered_permissions) < len(permission_ids):
missing_ids = permission_id_set - {p.id for p in filtered_permissions if p.id}
logger.warning(
f"Could not find all requested permission IDs for document {doc_id}. "
f"Missing IDs: {missing_ids}"
)
return filtered_permissions

View File

@@ -1,44 +0,0 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING
# Avoid circular imports
if TYPE_CHECKING:
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
class FetchAllDocumentsFunction(Protocol):
"""Protocol for a function that fetches all document IDs for a connector credential pair."""
def __call__(self) -> list[str]:
"""
Returns a list of document IDs for a connector credential pair.
This is typically used to determine which documents should no longer be
accessible during the document sync process.
"""
...
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
FetchAllDocumentsFunction,
Optional["IndexingHeartbeatInterface"],
],
Generator["DocExternalAccess", None, None],
]
GroupSyncFuncType = Callable[
[
str,
"ConnectorCredentialPair",
],
list["ExternalUserGroup"],
]
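
FetchAllDocumentsFunction above is a zero-argument callable Protocol, so any closure that returns the connector's document IDs satisfies it. A hedged sketch of constructing one (the in-memory fetcher is hypothetical; the real caller would close over a DB session and cc_pair):

from typing import Protocol

class FetchAllDocumentsFunction(Protocol):
    """Returns all document IDs for a connector credential pair."""
    def __call__(self) -> list[str]: ...

def make_fetcher(doc_ids: list[str]) -> FetchAllDocumentsFunction:
    def _fetch() -> list[str]:
        return list(doc_ids)
    return _fetch

fetcher: FetchAllDocumentsFunction = make_fetcher(["doc-1", "doc-2"])
print(fetcher())  # ['doc-1', 'doc-2']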

View File

@@ -2,7 +2,6 @@ from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
@@ -132,7 +131,6 @@ def _get_slack_document_access(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -1,21 +1,42 @@
from collections.abc import Callable
from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.slack.group_sync import slack_group_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],
]
# These functions update:
# - the user_email <-> document mapping
@@ -29,12 +50,6 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DocumentSource.GMAIL: gmail_doc_sync,
}
def source_requires_doc_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires doc syncing."""
return source in DOC_PERMISSIONS_FUNC_MAP
# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
@@ -46,21 +61,11 @@ GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
}
def source_requires_external_group_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
return source in GROUP_PERMISSIONS_FUNC_MAP
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DocumentSource.CONFLUENCE,
}
def source_group_sync_is_cc_pair_agnostic(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
return source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes

View File

@@ -10,7 +10,6 @@ from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
from ee.onyx.server.enterprise_settings.api import (
admin_router as enterprise_settings_admin_router,
)
@@ -168,7 +167,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
from ee.onyx.db.standard_answer import find_matching_standard_answers
from ee.onyx.server.manage.models import StandardAnswer as PydanticStandardAnswer
from onyx.configs.constants import MessageType
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
from onyx.db.chat import create_chat_session
@@ -23,7 +24,6 @@ from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.manage.models import StandardAnswer as PydanticStandardAnswer
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger

View File

@@ -1,177 +0,0 @@
from datetime import datetime
from http import HTTPStatus
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.background.celery.tasks.doc_permission_syncing.tasks import (
try_creating_permissions_sync_task,
)
from ee.onyx.background.celery.tasks.external_group_syncing.tasks import (
try_creating_external_group_sync_task,
)
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/manage")
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_perm_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[None]:
"""Triggers permissions sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Permissions sync task already in progress.",
)
logger.info(
f"Permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Permissions sync task creation failed.",
)
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the permissions sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
def get_cc_pair_latest_group_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
)
return cc_pair.last_time_external_group_sync
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
def sync_cc_pair_groups(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[None]:
"""Triggers group sync on a particular cc_pair immediately"""
tenant_id = get_current_tenant_id()
cc_pair = get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.external_group_sync.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="External group sync task already in progress.",
)
logger.info(
f"External group sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
)
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",
)
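
The removed router exposed manual trigger endpoints under the /manage prefix. A hedged example of how an admin client might have invoked the permissions-sync trigger (base URL, global API prefix, and auth header are assumptions, not taken from this diff):

import requests

BASE_URL = "http://localhost:8080"  # assumed deployment URL; a global API prefix may also apply
CC_PAIR_ID = 1

resp = requests.post(
    f"{BASE_URL}/manage/admin/cc-pair/{CC_PAIR_ID}/sync-permissions",
    headers={"Authorization": "Bearer <admin-or-curator-token>"},  # auth scheme assumed
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # StatusResponse payload with success/message fields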

View File

@@ -29,11 +29,7 @@ from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import get_current_tenant_id
admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")
@@ -122,11 +118,6 @@ def put_settings(
@basic_router.get("")
def fetch_settings() -> EnterpriseSettings:
if MULTI_TENANT:
tenant_id = get_current_tenant_id()
if not tenant_id or tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise BasicAuthenticationError(detail="User must authenticate")
return load_settings()

View File

@@ -0,0 +1,98 @@
import re
from typing import Any
from pydantic import BaseModel
from pydantic import field_validator
from pydantic import model_validator
from onyx.db.models import StandardAnswer as StandardAnswerModel
from onyx.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
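The validators above enforce three rules: at least one category, no `match_any_keywords` in regex mode, and a compilable pattern when `match_regex` is set. A quick illustration of how they behave, assuming `StandardAnswerCreationRequest` from the module above is in scope (the field values are made up):

```python
from pydantic import ValidationError

# Valid: keyword mode with at least one category attached.
req = StandardAnswerCreationRequest(
    keyword="vpn",
    answer="See the VPN setup guide.",
    categories=[1],
    match_regex=False,
    match_any_keywords=True,
)

# Invalid: regex mode with a pattern that does not compile.
try:
    StandardAnswerCreationRequest(
        keyword="[unclosed",
        answer="...",
        categories=[1],
        match_regex=True,
        match_any_keywords=False,
    )
except ValidationError as e:
    print(e)  # reports: invalid regex pattern r"[unclosed" in `keyword`
```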

View File

@@ -12,13 +12,13 @@ from ee.onyx.db.standard_answer import insert_standard_answer_category
from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from ee.onyx.server.manage.models import StandardAnswer
from ee.onyx.server.manage.models import StandardAnswerCategory
from ee.onyx.server.manage.models import StandardAnswerCategoryCreationRequest
from ee.onyx.server.manage.models import StandardAnswerCreationRequest
from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory
from onyx.server.manage.models import StandardAnswerCategoryCreationRequest
from onyx.server.manage.models import StandardAnswerCreationRequest
router = APIRouter(prefix="/manage")

View File

@@ -8,8 +8,8 @@ from fastapi import Request
from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis

View File

@@ -14,11 +14,11 @@ from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL

View File

@@ -11,11 +11,11 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds

View File

@@ -9,11 +9,11 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential

View File

@@ -43,6 +43,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from onyx.server.query_and_chat.models import ChatMessageDetail
@@ -338,7 +339,10 @@ def handle_send_message_simple_with_history(
provider_type=llm.config.model_provider,
)
max_history_tokens = int(llm.config.max_input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name, model_provider=llm.config.model_provider
)
max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(

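The change above derives the history budget from `get_max_input_tokens` instead of reading `llm.config.max_input_tokens`. A hedged sketch of that calculation; the percentage constant below is an assumed value for illustration, not the real config:

```python
from onyx.llm.utils import get_max_input_tokens

CHAT_TARGET_CHUNK_PERCENTAGE = 0.5  # assumed value, for illustration only


def history_token_budget(model_name: str, model_provider: str) -> int:
    # look up the model's max context window and reserve only a
    # fraction of it for chat history
    input_tokens = get_max_input_tokens(
        model_name=model_name, model_provider=model_provider
    )
    return int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
```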
View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
@@ -18,7 +19,6 @@ from onyx.context.search.models import ChunkContext
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.server.manage.models import StandardAnswer
class StandardAnswerRequest(BaseModel):

View File

@@ -38,6 +38,7 @@ from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
@@ -176,9 +177,10 @@ def get_answer_stream(
provider_type=llm.config.model_provider,
)
max_history_tokens = int(
llm.config.max_input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name, model_provider=llm.config.model_provider
)
max_history_tokens = int(input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE)
combined_message = combine_message_thread(
messages=query_request.messages,

View File

@@ -1,3 +1,5 @@
import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
@@ -11,37 +13,25 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QueryHistoryExport
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import MessageType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.pg_file_store import get_query_history_export_files
from onyx.db.tasks import get_task_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
@@ -51,16 +41,6 @@ router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def ensure_query_history_is_enabled(
disallowed: list[QueryHistoryType],
) -> None:
if ONYX_QUERY_HISTORY_TYPE in disallowed:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
@@ -139,12 +119,14 @@ def get_user_chat_sessions(
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
ensure_query_history_is_enabled(
[
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]
)
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
@@ -181,7 +163,11 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
@@ -219,7 +205,11 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
@@ -230,8 +220,7 @@ def get_chat_session_admin(
)
except ValueError:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
f"Chat session with id '{chat_session_id}' does not exist.",
400, f"Chat session with id '{chat_session_id}' does not exist."
)
snapshot = snapshot_from_chat_session(
chat_session=chat_session, db_session=db_session
@@ -239,7 +228,7 @@ def get_chat_session_admin(
if snapshot is None:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
400,
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
@@ -249,151 +238,52 @@ def get_chat_session_admin(
return snapshot
@router.get("/admin/query-history/list")
def list_all_query_history_exports(
@router.get("/admin/query-history-csv")
def get_query_history_as_csv(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[QueryHistoryExport]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
try:
pending_tasks = [
QueryHistoryExport.from_task(task)
for task in get_all_query_history_export_tasks(db_session=db_session)
]
generated_files = [
QueryHistoryExport.from_file(file)
for file in get_query_history_export_files(db_session=db_session)
]
merged = pending_tasks + generated_files
# We sort based on the start time of the task.
# We also return it in reverse order, since viewing reports from most recent to least recent is the most common case.
merged.sort(key=lambda task: task.start_time, reverse=True)
return merged
except Exception as e:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR, f"Failed to get all tasks: {e}"
)
@router.post("/admin/query-history/start-export")
def start_query_history_export(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
start: datetime | None = None,
end: datetime | None = None,
) -> dict[str, str]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
start = start or datetime.fromtimestamp(0, tz=timezone.utc)
end = end or datetime.now(tz=timezone.utc)
if start >= end:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
f"Start time must come before end time, but instead got the start time coming after; {start=} {end=}",
)
task = client_app.send_task(
OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
priority=OnyxCeleryPriority.MEDIUM,
queue=OnyxCeleryQueues.CSV_GENERATION,
kwargs={
"start": start,
"end": end,
},
)
return {"request_id": task.id}
@router.get("/admin/query-history/export-status")
def get_query_history_export_status(
request_id: str,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
task = get_task_with_id(db_session=db_session, task_id=request_id)
if task:
return {"status": task.status}
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we check whether the export file has already been stored in the file store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store(db_session)
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)
if not has_file:
raise HTTPException(
HTTPStatus.NOT_FOUND,
f"No task with {request_id=} was found",
)
return {"status": TaskStatus.SUCCESS}
@router.get("/admin/query-history/download")
def download_query_history_csv(
request_id: str,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
has_file = file_store.has_file(
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
# this call is very expensive and is timing out at the endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
end=end or datetime.now(tz=timezone.utc),
feedback_type=None,
limit=None,
)
if has_file:
try:
csv_stream = file_store.read_file(report_name)
except Exception as e:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
f"Failed to read query history file: {str(e)}",
)
csv_stream.seek(0)
return StreamingResponse(
iter(csv_stream),
media_type=FileType.CSV,
headers={"Content-Disposition": f"attachment;filename={report_name}"},
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)
# If the file doesn't exist yet, it may still be processing.
# Therefore, we check the task queue to determine its status, if there is any.
task = get_task_with_id(db_session=db_session, task_id=request_id)
if not task:
raise HTTPException(
HTTPStatus.NOT_FOUND,
f"No task with {request_id=} was found",
)
# Create an in-memory text stream
stream = io.StringIO()
writer = csv.DictWriter(
stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
)
writer.writeheader()
for row in question_answer_pairs:
writer.writerow(row.to_json())
if task.status in [TaskStatus.STARTED, TaskStatus.PENDING]:
raise HTTPException(
HTTPStatus.ACCEPTED, f"Task with {request_id=} is still being worked on"
)
# Reset the stream's position to the start
stream.seek(0)
elif task.status == TaskStatus.FAILURE:
raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR,
f"Task with {request_id=} failed to be processed",
)
else:
# This is the final case in which `task.status == SUCCESS`
raise RuntimeError(
"The task was marked as success, the file was not found in the file store; this is an internal error..."
)
return StreamingResponse(
iter([stream.getvalue()]),
media_type="text/csv",
headers={"Content-Disposition": "attachment;filename=onyx_query_history.csv"},
)
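One side of this diff builds the CSV in memory with `csv.DictWriter` and streams it back. A compact, self-contained sketch of that pattern with FastAPI (the row data is a placeholder):

```python
import csv
import io

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/query-history-csv")
def query_history_csv() -> StreamingResponse:
    rows = [{"user_question": "hi", "ai_response": "hello"}]  # placeholder data

    stream = io.StringIO()
    writer = csv.DictWriter(stream, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    for row in rows:
        writer.writerow(row)
    stream.seek(0)

    return StreamingResponse(
        iter([stream.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": "attachment;filename=onyx_query_history.csv"},
    )
```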

View File

@@ -3,17 +3,12 @@ from uuid import UUID
from pydantic import BaseModel
from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
from onyx.auth.users import get_display_email
from onyx.background.task_utils import extract_task_id_from_query_history_report_name
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import PGFileStore
from onyx.db.models import TaskQueueState
class AbridgedSearchDoc(BaseModel):
@@ -221,59 +216,3 @@ class QuestionAnswerPairSnapshot(BaseModel):
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
class QueryHistoryExport(BaseModel):
task_id: str
status: TaskStatus
start: datetime
end: datetime
start_time: datetime
@classmethod
def from_task(
cls,
task_queue_state: TaskQueueState,
) -> "QueryHistoryExport":
start_end = task_queue_state.task_name.removeprefix(
f"{QUERY_HISTORY_TASK_NAME_PREFIX}_"
)
start, end = start_end.split("_")
if not task_queue_state.start_time:
raise RuntimeError("The start time of the task must always be present")
return cls(
task_id=task_queue_state.task_id,
status=task_queue_state.status,
start=datetime.fromisoformat(start),
end=datetime.fromisoformat(end),
start_time=task_queue_state.start_time,
)
@classmethod
def from_file(
cls,
file: PGFileStore,
) -> "QueryHistoryExport":
if not file.file_metadata or not isinstance(file.file_metadata, dict):
raise RuntimeError(
"The file metadata must be non-null, and must be of type `dict[str, str]`"
)
metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
task_id = extract_task_id_from_query_history_report_name(file.file_name)
return cls(
task_id=task_id,
status=TaskStatus.SUCCESS,
start=metadata.start,
end=metadata.end,
start_time=metadata.start_time,
)
class QueryHistoryFileMetadata(BaseModel):
start: datetime
end: datetime
start_time: datetime
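`from_task` recovers the export window by parsing the task name, which is a fixed prefix followed by two ISO-8601 timestamps joined with underscores. A small round-trip sketch of that encoding; the prefix value here is assumed, not the real constant:

```python
from datetime import datetime, timezone

QUERY_HISTORY_TASK_NAME_PREFIX = "query_history"  # assumed value


def build_task_name(start: datetime, end: datetime) -> str:
    return f"{QUERY_HISTORY_TASK_NAME_PREFIX}_{start.isoformat()}_{end.isoformat()}"


def parse_task_name(task_name: str) -> tuple[datetime, datetime]:
    # ISO timestamps contain no underscores, so a single split is safe
    start_end = task_name.removeprefix(f"{QUERY_HISTORY_TASK_NAME_PREFIX}_")
    start_str, end_str = start_end.split("_")
    return datetime.fromisoformat(start_str), datetime.fromisoformat(end_str)


name = build_task_name(
    datetime(2025, 1, 1, tzinfo=timezone.utc),
    datetime(2025, 2, 1, tzinfo=timezone.utc),
)
assert parse_task_name(name) == (
    datetime(2025, 1, 1, tzinfo=timezone.utc),
    datetime(2025, 2, 1, tzinfo=timezone.utc),
)
```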

View File

@@ -1,6 +1,5 @@
import contextlib
import secrets
import string
from typing import Any
from fastapi import APIRouter
@@ -10,6 +9,7 @@ from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi_users import exceptions
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,7 +28,6 @@ from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -39,21 +38,14 @@ router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
"""
Creates or updates a user account for SAML authentication.
For new users or users with non-web-login roles:
1. Generates a secure random password that meets validation criteria
2. Creates the user with appropriate role and verified status
SAML users never use this password directly as they authenticate via their
Identity Provider, but we need a valid password to satisfy system requirements.
"""
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
async with get_async_session_context_manager() as session:
async with get_async_session_context() as session:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
@@ -68,41 +60,15 @@ async def upsert_saml_user(email: str) -> User:
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
# Generate a secure random password meeting validation requirements
# We use a secure random password since we never need to know what it is
# (SAML users authenticate via their IdP)
secure_random_password = "".join(
[
# Ensure minimum requirements are met
secrets.choice(
string.ascii_uppercase
), # at least one uppercase
secrets.choice(
string.ascii_lowercase
), # at least one lowercase
secrets.choice(string.digits), # at least one digit
secrets.choice(
"!@#$%^&*()-_=+[]{}|;:,.<>?"
), # at least one special
# Fill remaining length with random chars (mix of all types)
"".join(
secrets.choice(
string.ascii_letters
+ string.digits
+ "!@#$%^&*()-_=+[]{}|;:,.<>?"
)
for _ in range(12)
),
]
)
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
# Create the user with SAML-appropriate settings
user = await user_manager.create(
UserCreate(
email=email,
password=secure_random_password, # Pass raw password, not hash
password=hashed_pass,
role=role,
is_verified=True, # SAML users are pre-verified by their IdP
)
)
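The diff swaps a hand-rolled password generator for `fastapi_users`' `PasswordHelper`. For reference, a minimal standalone version of the hand-rolled approach is sketched below; the exact complexity policy is an assumption, not taken from Onyx's validators.

```python
import secrets
import string


def generate_compliant_password(extra_length: int = 12) -> str:
    """Random password with at least one upper, lower, digit and special char."""
    specials = "!@#$%^&*()-_=+[]{}|;:,.<>?"
    required = [
        secrets.choice(string.ascii_uppercase),
        secrets.choice(string.ascii_lowercase),
        secrets.choice(string.digits),
        secrets.choice(specials),
    ]
    filler = [
        secrets.choice(string.ascii_letters + string.digits + specials)
        for _ in range(extra_length)
    ]
    return "".join(required + filler)
```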

View File

@@ -5,6 +5,7 @@ from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
@@ -16,7 +17,6 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger

View File

@@ -7,7 +7,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()

View File

@@ -39,13 +39,10 @@ from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
@@ -272,14 +269,8 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_configurations=[
ModelConfigurationUpsertRequest(
name=name,
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
max_input_tokens=None,
)
for name in ANTHROPIC_MODEL_NAMES
],
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
@@ -299,14 +290,8 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,
is_visible=model_name in OPEN_AI_VISIBLE_MODEL_NAMES,
max_input_tokens=None,
)
for model_name in OPEN_AI_MODEL_NAMES
],
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
@@ -421,6 +406,7 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")

View File

@@ -9,7 +9,7 @@ from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.utils.logger import setup_logger
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -1,5 +1,3 @@
from typing import cast
import numpy as np
import torch
import torch.nn.functional as F
@@ -41,10 +39,10 @@ logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -52,14 +50,13 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # specific to the model version!
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -95,15 +92,12 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_INTENT_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
return _INTENT_TOKENIZER
@@ -401,9 +395,9 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
tokens = tokenizer.convert_ids_to_tokens(input_ids)
if not len(tokens) == len(is_keyword):
raise ValueError("Length of tokens and keyword predictions must match")
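Each tokenizer here is lazily loaded once and cached in a module-level global. A minimal sketch of that singleton pattern with Hugging Face `transformers`, mirroring the cast used on one side of the diff:

```python
from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

_TOKENIZER: PreTrainedTokenizer | None = None


def get_tokenizer() -> PreTrainedTokenizer:
    global _TOKENIZER
    if _TOKENIZER is None:
        # AutoTokenizer may return the fast variant, so cast for the type
        # checker, as in the diff above.
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER
```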

View File

@@ -1,6 +1,5 @@
import json
import os
from typing import cast
import torch
import torch.nn as nn
@@ -14,14 +13,15 @@ class HybridClassifier(nn.Module):
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
# Keyword tokenwise binary classification layer
self.keyword_classifier = nn.Linear(config.dim, 2)
self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
# Intent Classifier layers
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.intent_classifier = nn.Linear(config.dim, 2)
self.pre_classifier = nn.Linear(
self.distilbert.config.dim, self.distilbert.config.dim
)
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.device = torch.device("cpu")
@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
@@ -79,9 +79,8 @@ class ConnectorClassifier(nn.Module):
self.config = config
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
self.connector_global_classifier = nn.Linear(config.dim, 1)
self.connector_match_classifier = nn.Linear(config.dim, 1)
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token indicating end of connector name, and on which classifier is used
@@ -96,7 +95,7 @@ class ConnectorClassifier(nn.Module):
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert( # type: ignore
hidden_states = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
@@ -115,10 +114,7 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
)
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
device = (
torch.device("cuda")
if torch.cuda.is_available()

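Both models attach small linear heads to a DistilBERT encoder and size them from the encoder's config. A minimal, self-contained sketch of that wiring, here with a single intent head over the CLS token:

```python
import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel


class TinyIntentClassifier(nn.Module):
    """Linear head over the [CLS] token of a DistilBERT encoder."""

    def __init__(self) -> None:
        super().__init__()
        config = DistilBertConfig()
        self.distilbert = DistilBertModel(config)
        # config.dim is DistilBERT's hidden size (768 by default)
        self.intent_classifier = nn.Linear(config.dim, 2)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        cls_hidden = outputs.last_hidden_state[:, 0, :]
        return self.intent_classifier(cls_hidden)
```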
View File

@@ -8,11 +8,6 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class ExternalAccess:
# arbitrary limit to prevent excessively large permissions sets
# not internally enforced ... the caller can check this before using the instance
MAX_NUM_ENTRIES = 1000
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
@@ -36,10 +31,6 @@ class ExternalAccess:
f"is_public={self.is_public})"
)
@property
def num_entries(self) -> int:
return len(self.external_user_emails) + len(self.external_user_group_ids)
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -40,7 +40,6 @@ def process_llm_stream(
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to

View File

@@ -26,7 +26,7 @@ def decide_refinement_need(
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = graph_config.behavior.allow_refinement
decision = True # TODO: just for current testing purposes
if state.answer_error:
return RequireRefinemenEvalUpdate(

View File

@@ -74,9 +74,9 @@ def extract_entities_terms(
# Calculation here is only approximate
doc_context = trim_prompt_piece(
config=graph_config.tooling.fast_llm.config,
prompt_piece=doc_context,
reserved_str=ENTITY_TERM_EXTRACTION_PROMPT
graph_config.tooling.fast_llm.config,
doc_context,
ENTITY_TERM_EXTRACTION_PROMPT
+ question
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)

View File

@@ -267,9 +267,9 @@ def generate_validate_refined_answer(
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
config=model.config,
prompt_piece=relevant_docs_str,
reserved_str=base_prompt
model.config,
relevant_docs_str,
base_prompt
+ question
+ sub_question_answer_str
+ initial_answer

View File

@@ -10,7 +10,6 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.configs.agent_configs import AGENT_MAX_VERIFICATION_HITS
def kickoff_verification(
@@ -23,7 +22,7 @@ def kickoff_verification(
are done here, so this could be replaced with an edge. But we may choose to make state
updates later.)
"""
retrieved_documents = state.retrieved_documents[:AGENT_MAX_VERIFICATION_HITS]
retrieved_documents = state.retrieved_documents
verification_question = state.question
sub_question_id = state.sub_question_id

View File

@@ -71,9 +71,7 @@ def verify_documents(
fast_llm = graph_config.tooling.fast_llm
document_content = trim_prompt_piece(
config=fast_llm.config,
prompt_piece=document_content,
reserved_str=DOCUMENT_VERIFICATION_PROMPT + question,
fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
)
msg = [

View File

@@ -1,8 +1,6 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
@@ -12,21 +10,13 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -40,49 +30,6 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
# TODO: Add trimming logic
history_segments = []
for msg in prompt_builder.message_history:
if isinstance(msg, HumanMessage):
role = "User"
elif isinstance(msg, AIMessage):
role = "Assistant"
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
def _expand_query(
query: str,
expansion_type: QueryExpansionType,
prompt_builder: AnswerPromptBuilder,
) -> str:
history_str = _create_history_str(prompt_builder)
if history_str:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query, history=history_str)
else:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query)
msg = HumanMessage(content=expansion_prompt)
primary_llm, _ = get_default_llms()
response = primary_llm.invoke([msg])
rephrased_query: str = cast(str, response.content)
return rephrased_query
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
@@ -105,16 +52,7 @@ def choose_tool(
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
expanded_keyword_thread: TimeoutThread[str] | None = None
expanded_semantic_thread: TimeoutThread[str] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
@@ -134,20 +72,11 @@ def choose_tool(
agent_config.inputs.search_request.query,
)
if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
expanded_keyword_thread = run_in_background(
_expand_query,
agent_config.inputs.search_request.query,
QueryExpansionType.KEYWORD,
prompt_builder,
)
expanded_semantic_thread = run_in_background(
_expand_query,
agent_config.inputs.search_request.query,
QueryExpansionType.SEMANTIC,
prompt_builder,
)
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
structured_response_format = agent_config.inputs.structured_response_format
tools = [
@@ -280,23 +209,6 @@ def choose_tool(
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
if (
selected_tool.name == SearchTool._NAME
and expanded_keyword_thread
and expanded_semantic_thread
):
keyword_expansion = wait_on_background(expanded_keyword_thread)
semantic_expansion = wait_on_background(expanded_semantic_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.expanded_queries = QueryExpansions(
keywords_expansions=[keyword_expansion],
semantic_expansions=[semantic_expansion],
)
logger.info(f"Original query: {agent_config.inputs.search_request.query}")
logger.info(f"Expanded keyword queries: {keyword_expansion}")
logger.info(f"Expanded semantic queries: {semantic_expansion}")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,

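The removed block kicks off the keyword and semantic expansions on background threads and only waits for them once the search tool is actually chosen. A rough equivalent using the standard library's `concurrent.futures` in place of Onyx's `run_in_background` helpers; `expand_query` is a stand-in for the LLM-backed `_expand_query` above:

```python
from concurrent.futures import ThreadPoolExecutor


def expand_query(query: str, mode: str) -> str:
    # stand-in for the LLM-backed _expand_query shown above
    return f"{mode} expansion of: {query}"


query = "how do I rotate my API key?"
with ThreadPoolExecutor(max_workers=2) as pool:
    keyword_future = pool.submit(expand_query, query, "keyword")
    semantic_future = pool.submit(expand_query, query, "semantic")

    # ... tool selection happens here; block on the futures only if the
    # search tool was actually picked ...
    keyword_expansion = keyword_future.result()
    semantic_expansion = semantic_future.result()
```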
View File

@@ -9,7 +9,6 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.context.search.utils import dedupe_documents
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -51,16 +50,16 @@ def basic_use_tool_response(
final_search_results = []
initial_search_results = []
initial_search_document_ids: set[str] = set()
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
# use same function from _handle_search_tool_response_summary
initial_search_results = [
section_to_llm_doc(section)
for section in dedupe_documents(search_response_summary.top_sections)[0]
]
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_document_ids:
initial_search_document_ids.add(section.center_chunk.document_id)
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -30,7 +30,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
@@ -51,6 +51,7 @@ def _parse_agent_event(
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
@@ -104,7 +105,7 @@ def run_graph(
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
)
config.behavior.allow_refinement = AGENT_ALLOW_REFINEMENT
config.behavior.allow_refinement = ALLOW_REFINEMENT
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input

View File

@@ -17,6 +17,7 @@ from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
@@ -43,9 +44,9 @@ def build_sub_question_answer_prompt(
docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config=config,
prompt_piece=docs_str,
reserved_str=SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
@@ -60,9 +61,15 @@ def build_sub_question_answer_prompt(
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# TODO: save the max input tokens in LLMConfig
max_tokens = get_max_input_tokens(
model_provider=config.model_provider,
model_name=config.model_name,
)
# no need to trim if a conservative estimate of one token
# per character is already less than the max tokens
if len(prompt_piece) + len(reserved_str) < config.max_input_tokens:
if len(prompt_piece) + len(reserved_str) < max_tokens:
return prompt_piece
llm_tokenizer = get_tokenizer(
@@ -73,8 +80,7 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
# slightly conservative trimming
return tokenizer_trim_content(
content=prompt_piece,
desired_length=config.max_input_tokens
- len(llm_tokenizer.encode(reserved_str)),
desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
tokenizer=llm_tokenizer,
)
@@ -174,3 +180,35 @@ def binary_string_test_after_answer_separator(
relevant_text = text.split(f"{separator}")[-1]
return binary_string_test(relevant_text, positive_value)
def build_dc_search_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = build_date_time_string()
docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]

View File

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -154,8 +153,3 @@ class AnswerGenerationDocuments(BaseModel):
BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"

View File

@@ -1,4 +1,3 @@
import base64
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
@@ -7,21 +6,8 @@ from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
import sendgrid # type: ignore
from sendgrid.helpers.mail import Attachment # type: ignore
from sendgrid.helpers.mail import Content
from sendgrid.helpers.mail import ContentId
from sendgrid.helpers.mail import Disposition
from sendgrid.helpers.mail import Email
from sendgrid.helpers.mail import FileContent
from sendgrid.helpers.mail import FileName
from sendgrid.helpers.mail import FileType
from sendgrid.helpers.mail import Mail
from sendgrid.helpers.mail import To
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SENDGRID_API_KEY
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
@@ -32,12 +18,11 @@ from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -191,70 +176,6 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
if SENDGRID_API_KEY:
send_email_with_sendgrid(
user_email, subject, html_body, text_body, mail_from, inline_png
)
return
send_email_with_smtplib(
user_email, subject, html_body, text_body, mail_from, inline_png
)
def send_email_with_sendgrid(
user_email: str,
subject: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
from_email = Email(mail_from) if mail_from else Email("noreply@onyx.app")
to_email = To(user_email)
mail = Mail(
from_email=from_email,
to_emails=to_email,
subject=subject,
plain_text_content=Content("text/plain", text_body),
)
# Add HTML content
mail.add_content(Content("text/html", html_body))
if inline_png:
image_name, image_data = inline_png
# Create attachment
encoded_image = base64.b64encode(image_data).decode()
attachment = Attachment()
attachment.file_content = FileContent(encoded_image)
attachment.file_name = FileName(image_name)
attachment.file_type = FileType("image/png")
attachment.disposition = Disposition("inline")
attachment.content_id = ContentId(image_name)
mail.add_attachment(attachment)
# Get a JSON-ready representation of the Mail object
mail_json = mail.get()
sg = sendgrid.SendGridAPIClient(api_key=SENDGRID_API_KEY)
response = sg.client.mail.send.post(request_body=mail_json) # can raise
if response.status_code != 202:
logger.warning(f"Unexpected status code {response.status_code}")
def send_email_with_smtplib(
user_email: str,
subject: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
@@ -289,10 +210,13 @@ def send_email_with_smtplib(
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
except Exception as e:
raise e
def send_subscription_cancellation_email(user_email: str) -> None:
@@ -340,13 +264,27 @@ def send_subscription_cancellation_email(user_email: str) -> None:
)
def build_user_email_invite(
from_email: str, to_email: str, application_name: str, auth_type: AuthType
) -> tuple[str, str]:
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
onyx_file: FileWithMimeType | None = None
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Invitation to Join {application_name} Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {from_email} to join an organization on {application_name}.</p>"
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
@@ -371,7 +309,7 @@ def build_user_email_invite(
raise ValueError(f"Invalid auth type: {auth_type}")
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={to_email}"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(
application_name,
@@ -384,36 +322,13 @@ def build_user_email_invite(
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {from_email} to join an organization on {application_name}.\n"
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={to_email}\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
return text_content, html_content
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Invitation to Join {application_name} Organization"
text_content, html_content = build_user_email_invite(
current_user.email, user_email, application_name, auth_type
)
send_email(
user_email,
subject,

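The SMTP path assembles a multipart/alternative message and sends it over STARTTLS. A self-contained sketch using only the standard library; the server settings below are placeholders, not Onyx's config values:

```python
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def send_plain_and_html(
    to_email: str,
    subject: str,
    text_body: str,
    html_body: str,
    mail_from: str = "noreply@example.com",  # placeholder sender
    smtp_server: str = "localhost",          # placeholder server settings
    smtp_port: int = 587,
    smtp_user: str = "",
    smtp_pass: str = "",
) -> None:
    # "alternative" marks the text and HTML parts as versions of the same content
    msg = MIMEMultipart("alternative")
    msg["Subject"] = subject
    msg["From"] = mail_from
    msg["To"] = to_email
    msg.attach(MIMEText(text_body, "plain"))
    msg.attach(MIMEText(html_body, "html"))

    with smtplib.SMTP(smtp_server, smtp_port) as s:
        s.starttls()
        if smtp_user:
            s.login(smtp_user, smtp_pass)
        s.send_message(msg)
```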
View File

@@ -56,6 +56,7 @@ from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
@@ -76,7 +77,6 @@ from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
@@ -92,7 +92,7 @@ from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -253,7 +253,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user_email)
async with get_async_session_context_manager(tenant_id) as db_session:
async with get_async_session_with_tenant(tenant_id) as db_session:
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
@@ -296,7 +296,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
user: User
async with get_async_session_context_manager(tenant_id) as db_session:
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
@@ -402,7 +402,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Proceed with the tenant context
token = None
async with get_async_session_context_manager(tenant_id) as db_session:
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
@@ -642,7 +642,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return None
# Create a tenant-specific session
async with get_async_session_context_manager(tenant_id) as tenant_session:
async with get_async_session_with_tenant(tenant_id) as tenant_session:
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
tenant_session, User
)

View File

@@ -152,10 +152,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
try:
beat_multiplier = OnyxRuntime.get_beat_multiplier()
except Exception:
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File

@@ -94,5 +94,7 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -113,5 +113,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,5 +92,6 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -108,19 +108,14 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
replication_info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, replication_info.get("role", ""))
connected_slaves: int = replication_info.get("connected_slaves", 0)
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
memory_info: dict[str, Any] = cast(dict, r.info("memory"))
maxmemory_policy: str = cast(str, memory_info.get("maxmemory_policy", ""))
logger.info(f"Redis INFO MEMORY: maxmemory_policy={maxmemory_policy}")
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
@@ -289,6 +284,8 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.external_group_syncing",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",

View File

@@ -4,6 +4,7 @@ from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
@@ -15,14 +16,72 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_task.status,
)
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
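The fence check used by _get_deletion_status above relies on RedisConnectorDelete, whose internals are not shown in this hunk; the sketch below is a guess at that pattern, with made-up key names, purely for illustration.
import redis
r = redis.Redis()
fence_key = "connectordeletion_fence:42"  # hypothetical key for cc_pair id 42
# a worker sets the fence when it kicks off a deletion attempt ...
r.set(fence_key, 1)
# ... so "fenced" reduces to a presence check on that key, which is why
# _get_deletion_status can report deletion state without any DB row.
def fenced() -> bool:
    return bool(r.exists(fence_key))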

View File

@@ -226,16 +226,6 @@ if not MULTI_TENANT:
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "celery-beat-heartbeat",
"task": OnyxCeleryTask.CELERY_BEAT_HEARTBEAT,
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.PRIMARY,
},
},
]
)

View File

@@ -28,9 +28,6 @@ from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import (
delete_userfiles_for_cc_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
@@ -47,7 +44,6 @@ from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
@@ -449,9 +445,6 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# delete orphan tags
delete_orphan_tags__no_commit(db_session)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
@@ -471,12 +464,6 @@ def monitor_connector_deletion_taskset(
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# delete all userfiles for the cc_pair
delete_userfiles_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -16,10 +16,6 @@ from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import retry
from tenacity import retry_if_exception
from tenacity import stop_after_delay
from tenacity import wait_random_exponential
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -35,6 +31,7 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -50,10 +47,8 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -62,7 +57,6 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -79,12 +73,11 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60
DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
@@ -456,21 +449,7 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
# pass in the capability to fetch all existing docs for the cc_pair
# this can be used to determine documents that are "missing" and thus
# should no longer be accessible. The decision as to whether we should find
# every document during the doc sync process is connector-specific.
def fetch_all_existing_docs_fn() -> list[str]:
return get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
document_external_accesses = doc_sync_func(
cc_pair, fetch_all_existing_docs_fn, callback
)
document_external_accesses = doc_sync_func(cc_pair, callback)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -478,13 +457,13 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.update_db(
redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
task_logger=task_logger,
)
tasks_generated += 1
@@ -497,7 +476,6 @@ def connector_permission_sync_generator_task(
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
@@ -518,28 +496,33 @@ def connector_permission_sync_generator_task(
)
# NOTE(rkuo): this should probably move to the db layer
@retry(
retry=retry_if_exception(is_retryable_sqlalchemy_error),
wait=wait_random_exponential(
multiplier=1, max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT
),
stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
@shared_task(
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
)
def document_update_permissions(
def update_external_document_permissions_task(
self: Task,
tenant_id: str,
permissions: DocExternalAccess,
source_type_str: str,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
doc_id = permissions.doc_id
external_access = permissions.external_access
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
@@ -551,7 +534,7 @@ def document_update_permissions(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_type_str),
source_type=DocumentSource(source_string),
)
if created_new_doc:
@@ -570,105 +553,32 @@ def document_update_permissions(
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
task_logger.exception(
f"document_update_permissions exceptioned: "
f"update_external_document_permissions_task exceptioned: "
f"connector_id={connector_id} doc_id={doc_id}"
)
raise e
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"document_update_permissions completed: connector_id={connector_id} doc={doc_id}"
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
# large permissions through celery (over 1MB in size)
# @shared_task(
# name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
# soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
# time_limit=LIGHT_TIME_LIMIT,
# max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
# bind=True,
# )
# def update_external_document_permissions_task(
# self: Task,
# tenant_id: str,
# serialized_doc_external_access: dict,
# source_string: str,
# connector_id: int,
# credential_id: int,
# ) -> bool:
# start = time.monotonic()
# completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
# document_external_access = DocExternalAccess.from_dict(
# serialized_doc_external_access
# )
# doc_id = document_external_access.doc_id
# external_access = document_external_access.external_access
# try:
# with get_session_with_current_tenant() as db_session:
# # Add the users to the DB if they don't exist
# batch_add_ext_perm_user_if_not_exists(
# db_session=db_session,
# emails=list(external_access.external_user_emails),
# continue_on_error=True,
# )
# # Then upsert the document's external permissions
# created_new_doc = upsert_document_external_perms(
# db_session=db_session,
# doc_id=doc_id,
# external_access=external_access,
# source_type=DocumentSource(source_string),
# )
# if created_new_doc:
# # If a new document was created, we associate it with the cc_pair
# upsert_document_by_connector_credential_pair(
# db_session=db_session,
# connector_id=connector_id,
# credential_id=credential_id,
# document_ids=[doc_id],
# )
# elapsed = time.monotonic() - start
# task_logger.info(
# f"connector_id={connector_id} "
# f"doc={doc_id} "
# f"action=update_permissions "
# f"elapsed={elapsed:.2f}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
# except Exception as e:
# error_msg = format_error_for_logging(e)
# task_logger.warning(
# f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
# )
# task_logger.exception(
# f"update_external_document_permissions_task exceptioned: "
# f"connector_id={connector_id} doc_id={doc_id}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
# finally:
# task_logger.info(
# f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
# )
# if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
# return False
# task_logger.info(
# f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
# )
# return True
def validate_permission_sync_fences(
tenant_id: str,
r: Redis,
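For reference, the tenacity-based retry decorator appearing in this hunk composes as sketched below; the predicate body and the wrapped function are placeholders, not the real implementations (the real predicate is imported from onyx.db.utils above).
from tenacity import retry
from tenacity import retry_if_exception
from tenacity import stop_after_delay
from tenacity import wait_random_exponential
DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60  # seconds, as defined in this hunk
DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60
def is_retryable_sqlalchemy_error(exc: BaseException) -> bool:
    return "deadlock" in str(exc).lower()  # placeholder predicate
@retry(
    retry=retry_if_exception(is_retryable_sqlalchemy_error),
    wait=wait_random_exponential(multiplier=1, max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT),
    stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
)
def write_permissions_row() -> None:
    ...  # in the real task this upserts the document's external permissions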

View File

@@ -14,9 +14,6 @@ from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils import (
mark_all_relevant_cc_pairs_as_external_group_synced,
)
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
from ee.onyx.db.external_perm import ExternalUserGroup
@@ -41,6 +38,8 @@ from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -387,6 +386,24 @@ def connector_external_group_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -423,7 +440,7 @@ def connector_external_group_sync_generator_task(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
update_sync_record_status(
db_session=db_session,

View File

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import is_in_repeated_error_state
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
@@ -55,12 +54,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
@@ -243,16 +241,6 @@ def monitor_ccpair_indexing_taskset(
if not payload:
return
# if the CC Pair is `SCHEDULED`, move it to `INITIAL_INDEXING`. A CC Pair
# should only ever be `SCHEDULED` if it's a new connector.
cc_pair = get_connector_credential_pair_from_id(db_session, cc_pair_id)
if cc_pair is None:
raise RuntimeError(f"CC Pair {cc_pair_id} not found")
if cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED:
cc_pair.status = ConnectorCredentialPairStatus.INITIAL_INDEXING
db_session.commit()
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
@@ -367,24 +355,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
# mark the CC Pair as `ACTIVE` if the attempt was a success and the
# CC Pair is not already active
# This should never technically be in this state, but we'll handle it anyway
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
index_attempt_is_successful = index_attempt and index_attempt.status.is_successful()
if (
index_attempt_is_successful
and cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED
or cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
):
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
db_session.commit()
# if the index attempt is successful, clear the repeated error state
if cc_pair.in_repeated_error_state and index_attempt_is_successful:
cc_pair.in_repeated_error_state = False
db_session.commit()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_INDEXING,
@@ -471,21 +441,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# mark CC Pairs that are repeatedly failing as in repeated error state
with get_session_with_current_tenant() as db_session:
current_search_settings = get_current_search_settings(db_session)
for cc_pair_id in cc_pair_ids:
if is_in_repeated_error_state(
cc_pair_id=cc_pair_id,
search_settings_id=current_search_settings.id,
db_session=db_session,
):
set_cc_pair_repeated_error_state(
db_session=db_session,
cc_pair_id=cc_pair_id,
in_repeated_error_state=True,
)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
@@ -509,7 +464,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
search_settings_instance.id
)
if redis_connector_index.fenced:
task_logger.debug(
task_logger.info(
f"check_for_indexing - Skipping fenced connector: "
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
)
@@ -525,22 +480,29 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
)
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
task_logger.debug(
task_logger.info(
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
continue
else:
task_logger.debug(
task_logger.info(
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
@@ -898,16 +860,7 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# special bulletproofing ... truncate long exception messages
# for exception types that require more args, this will fail
# thus the try/except
try:
sanitized_e = type(e)(str(e)[:1024])
sanitized_e.__traceback__ = e.__traceback__
raise sanitized_e
except Exception:
raise e
raise e
finally:
if lock.owned():
@@ -1072,10 +1025,6 @@ def connector_indexing_proxy_task(
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
# track the last ttl and the time it was observed
last_activity_ttl_observed: float = time.monotonic()
last_activity_ttl: int = 0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1089,15 +1038,11 @@ def connector_indexing_proxy_task(
)
redis_connector_index.set_active() # renew active signal
# prime the connector active signal (renewed inside the connector)
redis_connector_index.set_connector_active()
redis_connector_index.set_connector_active() # prime the connector active signal
while True:
sleep(5)
now = time.monotonic()
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
@@ -1147,37 +1092,18 @@ def connector_indexing_proxy_task(
break
# if activity timeout is detected, break (exit point will clean up)
ttl = redis_connector_index.connector_active_ttl()
if ttl < 0:
# verify expectations around ttl
last_observed = last_activity_ttl_observed - now
if now > last_activity_ttl_observed + last_activity_ttl:
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout exceeded",
last_observed=f"{last_observed:.2f}s",
last_ttl=f"{last_activity_ttl}",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout exceeded",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
)
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
break
else:
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout expired unexpectedly, "
"waiting for last observed TTL before exiting",
last_observed=f"{last_observed:.2f}s",
last_ttl=f"{last_activity_ttl}",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
)
else:
last_activity_ttl_observed = now
last_activity_ttl = ttl
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status

View File

@@ -22,7 +22,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -32,8 +31,6 @@ from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
@@ -47,8 +44,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
@@ -129,8 +124,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
return False
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
@@ -353,42 +346,9 @@ def validate_indexing_fences(
return
def is_in_repeated_error_state(
cc_pair_id: int, search_settings_id: int, db_session: Session
) -> bool:
"""Checks if the cc pair / search setting combination is in a repeated error state."""
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise RuntimeError(
f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
)
# if the connector doesn't have a refresh_freq, a single failed attempt is enough
number_of_failed_attempts_in_a_row_needed = (
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
if cc_pair.connector.refresh_freq is not None
else 1
)
most_recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
limit=number_of_failed_attempts_in_a_row_needed,
db_session=db_session,
)
return len(
most_recent_index_attempts
) >= number_of_failed_attempts_in_a_row_needed and all(
attempt.status == IndexingStatus.FAILED
for attempt in most_recent_index_attempts
)
def should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
@@ -402,16 +362,6 @@ def should_index(
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
last_index_attempt = get_last_attempt_for_cc_pair(
cc_pair_id=cc_pair.id,
search_settings_id=search_settings_instance.id,
db_session=db_session,
)
all_recent_errored = is_in_repeated_error_state(
cc_pair_id=cc_pair.id,
search_settings_id=search_settings_instance.id,
db_session=db_session,
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
@@ -438,24 +388,24 @@ def should_index(
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index_attempt:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index_attempt.status == IndexingStatus.SUCCESS:
if last_index.status == IndexingStatus.SUCCESS:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
# )
return False
# No new index if the last index attempt is waiting to start
if last_index_attempt.status == IndexingStatus.NOT_STARTED:
if last_index.status == IndexingStatus.NOT_STARTED:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
# )
return False
# No new index if the last index attempt is running
if last_index_attempt.status == IndexingStatus.IN_PROGRESS:
if last_index.status == IndexingStatus.IN_PROGRESS:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
# )
@@ -489,27 +439,18 @@ def should_index(
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index_attempt:
if not last_index:
return True
if connector.refresh_freq is None:
# print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
return False
# if in the "initial" phase, we should always try to kick off indexing
# as soon as possible if there is no ongoing attempt. In other words,
# no delay UNLESS we're repeatedly failing to index.
if (
cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
and not all_recent_errored
):
return True
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index_attempt.time_updated
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
# print(
# f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index_attempt.id} "
# f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
# f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
# )
return False
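A tiny worked example of the refresh_freq gate above, with illustrative numbers only: an hourly refresh_freq and a last attempt 45 minutes ago means indexing is skipped.
from datetime import datetime
from datetime import timedelta
from datetime import timezone
refresh_freq = 3600  # seconds; illustrative value
last_index_time_updated = datetime.now(timezone.utc) - timedelta(minutes=45)
time_since_index = datetime.now(timezone.utc) - last_index_time_updated
print(time_since_index.total_seconds() < refresh_freq)  # True, so should_index returns False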

View File

@@ -10,7 +10,6 @@ from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
def _process_model_list_response(model_list_json: Any) -> list[str]:
@@ -89,10 +88,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
return None
# log change if any
old_models = set(
model_configuration.name
for model_configuration in default_provider.model_configurations
)
old_models = set(default_provider.model_names or [])
new_models = set(available_models)
added_models = new_models - old_models
removed_models = old_models - new_models
@@ -103,23 +99,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
task_logger.info(f"Removing models: {sorted(removed_models)}")
# Update the provider's model list
# Remove models that are no longer available
db_session.query(ModelConfiguration).filter(
ModelConfiguration.llm_provider_id == default_provider.id,
ModelConfiguration.name.notin_(available_models),
).delete(synchronize_session=False)
# Add new models
for available_model_name in available_models:
db_session.merge(
ModelConfiguration(
llm_provider_id=default_provider.id,
name=available_model_name,
is_visible=False,
max_input_tokens=None,
)
)
default_provider.model_names = available_models
# if the default model is no longer available, set it to the first model in the list
if default_provider.default_model_name not in available_models:
task_logger.info(

View File

@@ -49,7 +49,6 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
@@ -562,8 +561,6 @@ def monitor_ccpair_pruning_taskset(
num_docs_synced=initial,
)
delete_orphan_tags__no_commit(db_session)
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(None)

View File

@@ -6,14 +6,19 @@ import httpx
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.constants import ONYX_CELERY_BEAT_HEARTBEAT_KEY
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
@@ -22,13 +27,16 @@ from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@@ -262,14 +270,86 @@ def document_by_cc_pair_cleanup_task(
return True
@shared_task(name=OnyxCeleryTask.CELERY_BEAT_HEARTBEAT, ignore_result=True, bind=True)
def celery_beat_heartbeat(self: Task, *, tenant_id: str) -> None:
"""When this task runs, it writes a key to Redis with a TTL.
An external observer can check this key to figure out if the celery beat is still running.
"""
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
r: Redis = get_redis_client()
r.set(ONYX_CELERY_BEAT_HEARTBEAT_KEY, 1, ex=600)
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
# needed in the cloud
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
continue
self.app.send_task(
task_name,
kwargs=dict(
tenant_id=tenant_id,
),
queue=queue,
priority=priority,
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
finally:
if not lock_beat.owned():
task_logger.error(
"cloud_beat_task_generator - Lock not owned on completion"
)
redis_lock_dump(lock_beat, redis_client)
else:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"celery_beat_heartbeat finished: " f"elapsed={time_elapsed:.2f}")
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
return True
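One hedged sketch of how this generator could be wired into a cloud beat schedule, reusing the entry shape shown earlier in this diff; the entry name, interval, and target task are assumptions, while the enum members and constants mirror imports that appear in these hunks.
from datetime import timedelta
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
cloud_check_for_indexing_entry = {
    "name": "cloud_check_for_indexing",  # hypothetical entry name
    "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
    "schedule": timedelta(seconds=15),  # hypothetical interval
    "kwargs": {
        "task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
        "queue": OnyxCeleryQueues.PRIMARY,
        "priority": OnyxCeleryPriority.MEDIUM,
        "expires": BEAT_EXPIRES_DEFAULT,
    },
    "options": {
        "priority": OnyxCeleryPriority.HIGHEST,
        "expires": BEAT_EXPIRES_DEFAULT,
    },
}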

Some files were not shown because too many files have changed in this diff.