Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
0d52160a6f v2 2025-01-07 14:14:56 -08:00
pablodanswer
bbe9e9db74 add chrome extension
minor clean up

additional handling

post rebase fixes

nit

quick bump

finalize

minor cleanup

organizational

Revert changes in backend directory

Revert changes in deployment directory

push misc changes

improve shortcut display + general nrf page layout

minor clean up

quick nit

update chrome

k

build fix

k

update

k
2025-01-07 14:14:45 -08:00
1249 changed files with 32617 additions and 105449 deletions

1
.github/CODEOWNERS vendored
View File

@@ -1 +0,0 @@
* @onyx-dot-app/onyx-core-team

View File

@@ -1,14 +1,11 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] [Optional] Override Linear Check

View File

@@ -65,11 +65,8 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -12,43 +12,7 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: "true"
steps:
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
steps:
@@ -88,8 +52,6 @@ jobs:
provenance: false
build-arm64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
@@ -129,8 +91,7 @@ jobs:
provenance: false
merge-and-scan:
needs: [build-amd64, build-arm64, check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
needs: [build-amd64, build-arm64]
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub
@@ -157,6 +118,6 @@ jobs:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

View File

@@ -60,8 +60,6 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -53,90 +53,24 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# sarif_file: trivy-results.sarif

View File

@@ -1,6 +1,6 @@
name: Run Playwright Tests
name: Run Chromatic Tests
concurrency:
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
@@ -8,8 +8,6 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:
@@ -198,47 +196,43 @@ jobs:
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
chromatic-tests:
name: Chromatic Tests
# chromatic-tests:
# name: Chromatic Tests
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
# - name: Setup node
# uses: actions/setup-node@v4
# with:
# node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
# - name: Download Playwright test results
# uses: actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.17.0
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.7.0
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -94,27 +94,23 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -123,10 +119,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
@@ -134,38 +126,34 @@ jobs:
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -195,24 +183,15 @@ jobs:
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -222,8 +201,6 @@ jobs:
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
@@ -239,30 +216,27 @@ jobs:
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -1,29 +0,0 @@
name: Ensure PR references Linear
on:
pull_request:
types: [opened, edited, reopened, synchronize]
jobs:
linear-check:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."
exit 0
fi
# Looking for a checked override: "[x] Override Linear Check"
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
echo "Override box is checked. Check passed."
exit 0
fi
# Otherwise, fail the run
echo "No Linear link or override found in the PR description."
exit 1

View File

@@ -1,209 +0,0 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -1,7 +1,6 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -9,10 +8,6 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -44,26 +39,10 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
# Highspot
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -86,8 +65,6 @@ jobs:
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -1,29 +1,18 @@
name: Model Server Tests
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -37,23 +26,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -69,49 +41,6 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -127,23 +56,3 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

4
.gitignore vendored
View File

@@ -7,6 +7,4 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
/web/test-results/

View File

@@ -5,8 +5,6 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Skip warm up for dev
SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout
@@ -29,7 +27,6 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
@@ -52,9 +49,3 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -6,419 +6,354 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
}
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat"
],
"presentation": {
"group": "1",
}
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"model_server.main:app",
"--reload",
"--port",
"9000"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
},
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"onyx.main:app",
"--reload",
"--port",
"8080"
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
},
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
},
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
},
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
},
"presentation": {
"group": "2"
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
],
"presentation": {
"group": "2"
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3",
},
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}
}

View File

@@ -12,10 +12,6 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
Issues marked `good first issue` are an especially great place to start.
@@ -27,8 +23,8 @@ If you have a new/different contribution in mind, we'd love to hear about it!
Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
### Contributing Code
@@ -46,7 +42,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
and
[Discord](https://discord.gg/TDJ59cGV2X).
@@ -127,47 +123,7 @@ Once the above is done, navigate to `onyx/web` run:
npm i
```
## Formatting and Linting
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.
## Manually running the application for development
### Docker containers for external software
#### Docker containers for external software
You will need Docker installed to run these containers.
@@ -179,7 +135,7 @@ docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
### Running Onyx locally
#### Running Onyx locally
To start the frontend, navigate to `onyx/web` and run:
@@ -267,6 +223,35 @@ If you want to make changes to Onyx and run those changes in Docker, you can als
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
```
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
#### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
### Release Process

View File

@@ -1,30 +0,0 @@
# VSCode Debugging Setup
This guide explains how to set up and use VSCode's debugging capabilities with this project.
## Initial Setup
1. **Environment Setup**:
- Copy `.vscode/.env.template` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
## Using the Debugger
Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. CD into web, run "npm i" followed by npm run dev.
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
## Features
- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
- Console output is organized in the integrated terminal with labeled tabs

122
README.md
View File

@@ -24,94 +24,112 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring AI Assistants.
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Feature Highlights</h3>
<h3>Usage</h3>
**Deep research over your team's knowledge:**
Onyx Web App:
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
For more details on the Admin UI to manage connectors and users, check out our
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
## 💃 Main Features
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
- Chat UI with the ability to select documents to chat with.
- Create custom AI Assistants with different prompts and backing knowledge sets.
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
- Document Search + AI Answers for natural language queries.
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
- Slack integration to get answers and search results directly in Slack.
## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
- Chat/Prompt sharing with specific teammates and user groups.
- Multimodal model support, chat with images, video etc.
- Choosing between LLMs and parameters during chat session.
- Tool calling and agent configurations options.
- Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Onyx
- User Authentication with document level access management.
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
- Admin Dashboard to configure connectors, document-sets, access, etc.
- Custom deep learning models + learn from user feedback.
- Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
Efficiently pulls the latest changes from:
- Slack
- GitHub
- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gmail
- Notion
- Gong
- Microsoft Teams
- Dropbox
- Slab
- Linear
- Productboard
- Guru
- Bookstack
- Document360
- Sharepoint
- Hubspot
- Local Files
- Websites
- And more ...
See the full list [here](https://docs.onyx.app/connectors).
## 📚 Editions
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
- Single Sign-On (SSO), with support for both SAML and OIDC
- Role-based access control
- Document permission inheritance from connected sources
- Usage analytics and query history accessible to admins
- Whitelabeling
- API key authentication
- Encryption of secrets
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -8,11 +8,9 @@ Edition features outside of personal development or testing purposes. Please rea
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ARG ONYX_VERSION=0.8-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -28,16 +26,14 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
libgnutls30 \
libblkid1 \
libmount1 \
libsmartcols1 \
libuuid1 \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc \
nano \
vim && \
gcc && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -102,10 +98,8 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets

View File

@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ARG ONYX_VERSION=0.8-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
@@ -31,8 +31,7 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
@@ -46,7 +45,6 @@ WORKDIR /app
# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -84,7 +84,7 @@ keys = console
keys = generic
[logger_root]
level = INFO
level = WARN
handlers = console
qualname =

View File

@@ -25,9 +25,6 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
# Alembic Config object
config = context.config
@@ -39,7 +36,6 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -68,7 +64,7 @@ def include_object(
return True
def get_schema_options() -> tuple[str, bool, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -80,10 +76,6 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -94,12 +86,14 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
return schema_name, create_schema, upgrade_all_tenants
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -140,12 +134,7 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
@@ -162,15 +151,9 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -179,12 +162,7 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue is set, continuing to next schema.")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -202,11 +180,7 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
@@ -256,7 +230,6 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())

View File

@@ -1,29 +0,0 @@
"""add shortcut option for users
Revision ID: 027381bce97c
Revises: 6fc7886d665d
Create Date: 2025-01-14 12:14:00.814390
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "027381bce97c"
down_revision = "6fc7886d665d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
op.drop_column("user", "shortcut_enabled")

View File

@@ -1,36 +0,0 @@
"""add index to index_attempt.time_created
Revision ID: 0f7ff6d75b57
Revises: 369644546676
Create Date: 2025-01-10 14:01:14.067144
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0f7ff6d75b57"
down_revision = "fec3db967bf7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_index_attempt_status"),
"index_attempt",
["status"],
unique=False,
)
op.create_index(
op.f("ix_index_attempt_time_created"),
"index_attempt",
["time_created"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")

View File

@@ -1,27 +0,0 @@
"""Add indexes to document__tag
Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_document__tag_tag_id"),
"document__tag",
["tag_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")

View File

@@ -1,32 +0,0 @@
"""set built in to default
Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Prior to this migration / point in the codebase history,
# built in personas were implicitly treated as default personas (with no option to change this)
# This migration makes that explicit
op.execute(
"""
UPDATE persona
SET is_default_persona = TRUE
WHERE builtin_persona = TRUE
"""
)
def downgrade() -> None:
pass

View File

@@ -1,36 +0,0 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -1,35 +0,0 @@
"""add composite index for index attempt time updated
Revision ID: 369644546676
Revises: 2955778aa44c
Create Date: 2025-01-08 15:38:17.224380
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "369644546676"
down_revision = "2955778aa44c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
text("time_updated DESC"),
],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
table_name="index_attempt",
)

View File

@@ -1,51 +0,0 @@
"""add chunk stats table
Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chunk_stats",
sa.Column("id", sa.String(), primary_key=True, index=True),
sa.Column(
"document_id",
sa.String(),
sa.ForeignKey("document.id"),
nullable=False,
index=True,
),
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
sa.Column("information_content_boost", sa.Float(), nullable=True),
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
index=True,
server_default=sa.func.now(),
),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
sa.UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
op.create_index(
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
)
def downgrade() -> None:
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
op.drop_table("chunk_stats")

View File

@@ -1,125 +0,0 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -1,98 +0,0 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,59 +0,0 @@
"""add back input prompts
Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -1,50 +0,0 @@
"""update prompt length
Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)

View File

@@ -40,6 +40,6 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -1,64 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
connection = op.get_bind()
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -5,6 +5,7 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -19,8 +20,13 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -57,6 +63,7 @@ def upgrade() -> None:
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -64,12 +71,15 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -160,9 +170,10 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
pass
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -179,6 +190,8 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -260,7 +273,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
pass
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")

View File

@@ -1,117 +0,0 @@
"""duplicated no-harm user file migration
Revision ID: 6a804aeb4830
Revises: 8e1ac4f39a9f
Create Date: 2025-04-01 07:26:10.539362
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import datetime
# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Check if user_file table already exists
conn = op.get_bind()
inspector = inspect(conn)
if not inspector.has_table("user_file"):
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
pass

View File

@@ -1,80 +0,0 @@
"""make categories labels and many to many
Revision ID: 6fc7886d665d
Revises: 3c6531f32351
Create Date: 2025-01-13 18:12:18.029112
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6fc7886d665d"
down_revision = "3c6531f32351"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename persona_category table to persona_label
op.rename_table("persona_category", "persona_label")
# Create the new association table
op.create_table(
"persona__persona_label",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("persona_label_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["persona_label_id"],
["persona_label.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
)
# Copy existing relationships to the new table
op.execute(
"""
INSERT INTO persona__persona_label (persona_id, persona_label_id)
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
"""
)
# Remove the old category_id column from persona table
op.drop_column("persona", "category_id")
def downgrade() -> None:
# Rename persona_label table back to persona_category
op.rename_table("persona_label", "persona_category")
# Add back the category_id column to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"persona_category_id_fkey",
"persona",
"persona_category",
["category_id"],
["id"],
)
# Copy the first label relationship back to the persona table
op.execute(
"""
UPDATE persona
SET category_id = (
SELECT persona_label_id
FROM persona__persona_label
WHERE persona__persona_label.persona_id = persona.id
LIMIT 1
)
"""
)
# Drop the association table
op.drop_table("persona__persona_label")

View File

@@ -1,50 +0,0 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 9aadf32dfeb4
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "9aadf32dfeb4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")

View File

@@ -1,32 +0,0 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,72 +0,0 @@
"""Add SyncRecord
Revision ID: 97dbb53fa8c8
Revises: 369644546676
Create Date: 2025-01-11 19:39:50.426302
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"sync_record",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column(
"sync_type",
sa.Enum(
"DOCUMENT_SET",
"USER_GROUP",
"CONNECTOR_DELETION",
name="synctype",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column(
"sync_status",
sa.Enum(
"IN_PROGRESS",
"SUCCESS",
"FAILED",
"CANCELED",
name="syncstatus",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# Add index for fetch_latest_sync_record query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_start_time",
"sync_record",
["entity_id", "sync_type", "sync_start_time"],
)
# Add index for cleanup_sync_records query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_status",
"sync_record",
["entity_id", "sync_type", "sync_status"],
)
def downgrade() -> None:
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
op.drop_table("sync_record")

View File

@@ -1,107 +0,0 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create sub_question table
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
)
# Create sub_query table
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"agent__sub_query__search_doc",
sa.Column(
"sub_query_id",
sa.Integer,
sa.ForeignKey("agent__sub_query.id"),
primary_key=True,
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")
op.drop_table("agent__sub_query__search_doc")
op.drop_table("agent__sub_query")
op.drop_table("agent__sub_question")
op.drop_table("agent__search_metrics")

View File

@@ -1,113 +0,0 @@
"""add user files
Revision ID: 9aadf32dfeb4
Revises: 3781a5eb12cb
Create Date: 2025-01-26 16:08:21.551022
"""
import sqlalchemy as sa
import datetime
from alembic import op
# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
# Drop the persona__user_folder table
op.drop_table("persona__user_folder")
# Drop the persona__user_file table
op.drop_table("persona__user_file")
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")
op.drop_column("connector_credential_pair", "is_user_file")

View File

@@ -1,43 +0,0 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -1,29 +0,0 @@
"""remove inactive ccpair status on downgrade
Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212
"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update
# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
op.execute(
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
.values(status=ConnectorCredentialPairStatus.ACTIVE)
)

View File

@@ -1,27 +0,0 @@
"""add pinned assistants
Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
)
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
def downgrade() -> None:
op.drop_column("user", "pinned_assistants")

View File

@@ -1,31 +0,0 @@
"""nullable preferences
Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("user", "temperature_override_enabled", nullable=True)
op.alter_column("user", "auto_scroll", nullable=True)
def downgrade() -> None:
# Ensure no null values before making columns non-nullable
op.execute(
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
)
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
op.alter_column("user", "temperature_override_enabled", nullable=False)
op.alter_column("user", "auto_scroll", nullable=False)

View File

@@ -1,124 +0,0 @@
"""Add checkpointing/failure handling
Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"ix_index_attempt_cc_pair_settings_poll",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
"status",
sa.text("time_updated DESC"),
],
)
# Drop the old IndexAttemptError table
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# Create the new version of the table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("document_link", sa.String(), nullable=True),
sa.Column("entity_id", sa.String(), nullable=True),
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("failure_message", sa.Text(), nullable=False),
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
)
def downgrade() -> None:
op.execute("SET lock_timeout = '5s'")
# try a few times to drop the table, this has been observed to fail due to other locks
# blocking the drop
NUM_TRIES = 10
for i in range(NUM_TRIES):
try:
op.drop_table("index_attempt_errors")
break
except Exception as e:
if i == NUM_TRIES - 1:
raise e
print(f"Error dropping table: {e}. Retrying...")
op.execute("SET lock_timeout = DEFAULT")
# Recreate the old IndexAttemptError table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
)
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
op.drop_column("index_attempt", "checkpoint_pointer")
op.drop_column("index_attempt", "poll_range_start")
op.drop_column("index_attempt", "poll_range_end")

View File

@@ -1,55 +0,0 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")

View File

@@ -1,38 +0,0 @@
"""fix_capitalization
Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE document
SET
external_user_group_ids = ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
),
last_modified = NOW()
WHERE
external_user_group_ids IS NOT NULL
AND external_user_group_ids::text[] <> ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
)::text[]
"""
)
def downgrade() -> None:
# No way to cleanly persist the bad state through an upgrade/downgrade
# cycle, so we just pass
pass

View File

@@ -1,36 +0,0 @@
"""Add chat_message__standard_answer table
Revision ID: c5eae4a75a1b
Revises: 0f7ff6d75b57
Create Date: 2025-01-15 14:08:49.688998
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c5eae4a75a1b"
down_revision = "0f7ff6d75b57"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chat_message__standard_answer",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_message.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
)
def downgrade() -> None:
op.drop_table("chat_message__standard_answer")

View File

@@ -1,48 +0,0 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: fec3db967bf7
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")

View File

@@ -1,120 +0,0 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -1,36 +0,0 @@
"""add_default_vision_provider_to_llm_provider
Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_default_vision_provider",
sa.Boolean(),
nullable=True,
server_default=sa.false(),
),
)
op.add_column(
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("llm_provider", "default_vision_model")
op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -1,80 +0,0 @@
"""add default slack channel config
Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add is_default column
op.add_column(
"slack_channel_config",
sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
)
op.create_index(
"ix_slack_channel_config_slack_bot_id_default",
"slack_channel_config",
["slack_bot_id", "is_default"],
unique=True,
postgresql_where=sa.text("is_default IS TRUE"),
)
# Create default channel configs for existing slack bots without one
conn = op.get_bind()
slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()
for slack_bot in slack_bots:
slack_bot_id = slack_bot[0]
existing_default = conn.execute(
sa.text(
"SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
),
{"bot_id": slack_bot_id},
).fetchone()
if not existing_default:
conn.execute(
sa.text(
"""
INSERT INTO slack_channel_config (
slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
) VALUES (
:bot_id, NULL,
'{"channel_name": null, '
'"respond_member_group_list": [], '
'"answer_filters": [], '
'"follow_up_tags": [], '
'"respond_tag_only": true}',
FALSE, TRUE
)
"""
),
{"bot_id": slack_bot_id},
)
def downgrade() -> None:
# Delete default slack channel configs
conn = op.get_bind()
conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))
# Remove index
op.drop_index(
"ix_slack_channel_config_slack_bot_id_default",
table_name="slack_channel_config",
)
# Remove is_default column
op.drop_column("slack_channel_config", "is_default")

View File

@@ -1,36 +0,0 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")

View File

@@ -1,27 +0,0 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -1,33 +0,0 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1,40 +0,0 @@
"""Add background errors table
Revision ID: f39c5794c10a
Revises: 2cdeff6d8c93
Create Date: 2025-02-12 17:11:14.527876
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f39c5794c10a"
down_revision = "2cdeff6d8c93"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"background_error",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("message", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
def downgrade() -> None:
op.drop_table("background_error")

View File

@@ -1,53 +0,0 @@
"""delete non-search assistants
Revision ID: f5437cc136c5
Revises: eaa3b5593925
Create Date: 2025-02-04 16:17:15.677256
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f5437cc136c5"
down_revision = "eaa3b5593925"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
# Fix: split the statements into multiple op.execute() calls
op.execute(
"""
WITH personas_without_search AS (
SELECT p.id
FROM persona p
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
LEFT JOIN tool t ON pt.tool_id = t.id
GROUP BY p.id
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
)
UPDATE slack_channel_config
SET persona_id = NULL
WHERE is_default = TRUE AND persona_id IN (SELECT id FROM personas_without_search)
"""
)
op.execute(
"""
WITH personas_without_search AS (
SELECT p.id
FROM persona p
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
LEFT JOIN tool t ON pt.tool_id = t.id
GROUP BY p.id
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
)
DELETE FROM slack_channel_config
WHERE is_default = FALSE AND persona_id IN (SELECT id FROM personas_without_search)
"""
)

View File

@@ -1,50 +0,0 @@
"""add prompt length limit
Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -1,77 +0,0 @@
"""updated constraints for ccpairs
Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Drop the old foreign-key constraints
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# 2) Re-add them with ondelete='CASCADE'
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="connector",
local_cols=["connector_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="credential",
local_cols=["credential_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Reverse the changes for rollback
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# Recreate without CASCADE
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
"connector",
["connector_id"],
["id"],
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
"credential",
["credential_id"],
["id"],
)

View File

@@ -1,41 +0,0 @@
"""Add time_updated to UserGroup and DocumentSet
Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document_set",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"user_group",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_column("user_group", "time_last_modified_by_user")
op.drop_column("document_set", "time_last_modified_by_user")

View File

@@ -1,42 +0,0 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)

View File

@@ -1,33 +0,0 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create new_available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop new_available_tenant table
op.drop_table("available_tenant")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map

View File

@@ -4,11 +4,12 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -17,28 +18,11 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
#####
@@ -48,32 +32,35 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -2,79 +2,23 @@ from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import (
beat_cloud_tasks as base_beat_system_tasks,
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
)
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks
ee_tasks_to_schedule = [
{
"name": "autogenerate_usage_report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -8,9 +8,6 @@ from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@@ -18,7 +15,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")
@@ -46,59 +43,24 @@ def monitor_usergroup_taskset(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
try:
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
except Exception as e:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.FAILED,
num_docs_synced=initial_count,
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
raise e
rug.reset()

View File

@@ -2,6 +2,7 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -31,6 +32,8 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
@@ -25,10 +11,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -43,7 +25,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -64,26 +45,15 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
@@ -93,5 +63,3 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"

View File

@@ -345,8 +345,7 @@ def fetch_assistant_unique_users_total(
def user_can_view_assistant_stats(
db_session: Session, user: User | None, assistant_id: int
) -> bool:
# If user is None and auth is disabled, assume the user is an admin
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return True

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
only_sync: bool,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type with optional filtering by access_type and status
Get all cc_pairs for a given source type (and optionally only sync)
result is sorted by cc_pair id
"""
query = (
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
cc_pairs = query.all()
return cc_pairs

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)

View File

@@ -6,9 +6,8 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
@@ -62,10 +61,8 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
all_group_members = batch_add_ext_perm_user_if_not_exists(
db_session=db_session, emails=list(all_group_member_emails)
)
delete_user__ext_group_for_cc_pair__no_commit(
@@ -87,14 +84,12 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
cc_pair_id=cc_pair_id,
)
)

View File

@@ -2,11 +2,8 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
def make_persona_private(
@@ -15,9 +12,6 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -26,22 +20,11 @@ def make_persona_private(
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
if group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
for group_id in group_ids:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)

View File

@@ -1,142 +1,29 @@
from collections.abc import Sequence
from datetime import datetime
import datetime
from typing import Literal
from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
def _build_filter_conditions(
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
"""
Helper function to build all filter conditions for chat sessions.
Filters by start and end time, feedback type, and any sessions without messages.
start_time: Date from which to filter
end_time: Date to which to filter
feedback_filter: Feedback type to filter by
Returns: List of filter conditions
"""
conditions = []
if start_time is not None:
conditions.append(ChatSession.time_created >= start_time)
if end_time is not None:
conditions.append(ChatSession.time_created <= end_time)
if feedback_filter is not None:
feedback_subq = (
select(ChatMessage.chat_session_id)
.join(ChatMessageFeedback)
.group_by(ChatMessage.chat_session_id)
.having(
case(
(
case(
{literal(feedback_filter == QAFeedbackType.LIKE): True},
else_=False,
),
func.bool_and(ChatMessageFeedback.is_positive),
),
(
case(
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
else_=False,
),
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
),
else_=func.bool_or(ChatMessageFeedback.is_positive)
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
)
)
)
conditions.append(ChatSession.id.in_(feedback_subq))
return conditions
def get_total_filtered_chat_sessions_count(
db_session: Session,
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> int:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
stmt = (
select(func.count(distinct(ChatSession.id)))
.select_from(ChatSession)
.filter(*conditions)
)
return db_session.scalar(stmt) or 0
def get_page_of_chat_sessions(
start_time: datetime | None,
end_time: datetime | None,
db_session: Session,
page_num: int,
page_size: int,
feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id)
.filter(*conditions)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
)
stmt = (
select(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options(
joinedload(ChatSession.user),
joinedload(ChatSession.persona),
contains_eager(ChatSession.messages).joinedload(
ChatMessage.chat_message_feedbacks
),
)
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
)
return db_session.scalars(stmt).unique().all()
SortByOptions = Literal["time_sent"]
def fetch_chat_sessions_eagerly_by_time(
start: datetime,
end: datetime,
start: datetime.datetime,
end: datetime.datetime,
db_session: Session,
limit: int | None = 500,
initial_time: datetime | None = None,
initial_time: datetime.datetime | None = None,
) -> list[ChatSession]:
"""Sorted by oldest to newest, then by message id"""
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
@@ -149,7 +36,8 @@ def fetch_chat_sessions_eagerly_by_time(
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(asc_time_order)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.limit(limit)
.subquery()
)
@@ -165,7 +53,7 @@ def fetch_chat_sessions_eagerly_by_time(
ChatMessage.chat_message_feedbacks
),
)
.order_by(asc_time_order, message_order)
.order_by(time_order, message_order)
)
chat_sessions = query.all()

View File

@@ -7,7 +7,6 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
@@ -21,11 +20,10 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return stmt
stmt = stmt.distinct()
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
@@ -48,12 +46,6 @@ def _add_user_filters(
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
# If user is None, this is an anonymous user and we should only show public token_rate_limits
if user is None:
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -111,10 +103,10 @@ def insert_user_group_token_rate_limit(
return token_limit
def fetch_user_group_token_rate_limits_for_user(
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,

View File

@@ -16,20 +16,13 @@ from onyx.db.models import UsageReport
from onyx.file_store.file_store import get_default_file_store
# Gets skeletons of all messages in the given range
# Gets skeletons of all message
def get_empty_chat_messages_entries__paginated(
db_session: Session,
period: tuple[datetime, datetime],
limit: int | None = 500,
initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
"""Returns a tuple where:
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],
end=period[1],
@@ -59,17 +52,18 @@ def get_empty_chat_messages_entries__paginated(
if len(chat_sessions) == 0:
return None, []
return chat_sessions[-1].time_created, message_skeletons
return chat_sessions[0].time_created, message_skeletons
def get_all_empty_chat_message_entries(
db_session: Session,
period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
"""period is the range of time over which to fetch messages."""
initial_time: Optional[datetime] = period[0]
ind = 0
while True:
# iterate from oldest to newest
ind += 1
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,

View File

@@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(
def construct_document_select_by_usergroup(
user_group_id: int,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document.id)
select(Document)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
@@ -374,9 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name, time_last_modified_by_user=func.now()
)
db_user_group = UserGroup(name=user_group.name)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -424,7 +422,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC only if they were previously a CURATOR
# otherwise, set their role to BASIC
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,20 +629,7 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
_validate_curator_status__no_commit(db_session, list(removed_users))
db_session.commit()
return db_user_group
@@ -714,10 +699,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

View File

@@ -2,7 +2,6 @@
Rules defined here:
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
"""
from collections.abc import Generator
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
@@ -10,16 +9,11 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -30,9 +24,7 @@ _REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
viewspace_permissions = []
for permission_category in space_permissions:
@@ -75,13 +67,6 @@ def _get_server_space_permissions(
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
@@ -159,9 +144,6 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key
@@ -266,19 +248,14 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -288,9 +265,11 @@ def _fetch_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
@@ -324,9 +303,11 @@ def _fetch_all_page_restrictions(
continue
# If there are no restrictions, then use the space's restrictions
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
)
if (
not space_permissions.is_public
@@ -340,12 +321,12 @@ def _fetch_all_page_restrictions(
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -356,11 +337,7 @@ def confluence_doc_sync(
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
@@ -371,23 +348,14 @@ def confluence_doc_sync(
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
yield from _fetch_all_page_restrictions(
return _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -1,11 +1,8 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -13,81 +10,47 @@ logger = setup_logger()
def _build_group_member_email_map(
confluence_client: OnyxConfluence, cc_pair_id: int
confluence_client: OnyxConfluence,
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user}")
email = user.email
for user_result in confluence_client.paginated_cql_user_retrieval():
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server
user_name = user.username
user_name = user.get("username")
# If it is present, try to get the email using a Server-specific method
if user_name:
email = get_user_email_from_username__server(
confluence_client=confluence_client,
user_name=user_name,
)
if not email:
# If we still don't have an email, skip this user
msg = f"user result missing email field: {user}"
if user.type == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not all_users_groups:
msg = f"No groups found for user with email: {email}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
if not group_member_emails:
msg = "No groups found for any users."
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
return group_member_emails
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()

View File

@@ -1,4 +1,3 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -7,7 +6,6 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -16,7 +14,6 @@ logger = setup_logger()
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -26,16 +23,13 @@ def _get_slim_doc_generator(
)
return gmail_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -45,29 +39,25 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue
if user_email := slim_doc.perm_sync_data.get("user_email"):
ext_access = ExternalAccess(
external_user_emails=set([user_email]),
external_user_group_ids=set(),
is_public=False,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
document_external_access.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
)
return document_external_access

View File

@@ -1,4 +1,3 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -11,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -22,7 +20,6 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -32,9 +29,7 @@ def _get_slim_doc_generator(
)
return google_drive_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)
@@ -47,33 +42,34 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -107,13 +103,7 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -130,25 +120,16 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
external_user_group_ids=group_emails,
is_public=public,
)
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -162,19 +143,17 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
yield DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
)
return document_external_accesses

View File

@@ -1,128 +1,16 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
) -> dict[str, tuple[set[str], set[str]]]:
"""
This builds a map of drive ids to their members (group and user emails).
E.g. {
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
}
"""
drive_ids = google_drive_connector.get_all_drive_ids()
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
drive_service = get_drive_service(
google_drive_connector.creds,
google_drive_connector.primary_admin_email,
)
for drive_id in drive_ids:
group_emails: set[str] = set()
user_emails: set[str] = set()
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
return drive_id_to_members_map
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
"""
This gets all the group emails.
"""
group_emails: set[str] = set()
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email)",
):
group_emails.add(group["email"])
return group_emails
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
) -> dict[str, set[str]]:
"""
This maps group emails to their member emails.
"""
group_to_member_map: dict[str, set[str]] = {}
for group_email in group_emails:
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.add(member["email"])
group_to_member_map[group_email] = group_member_emails
return group_to_member_map
def _build_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, set[str]],
) -> list[ExternalUserGroup]:
onyx_groups: list[ExternalUserGroup] = []
# Convert all drive member definitions to onyx groups
# This is because having drive level access means you have
# irrevocable access to all the files in the drive.
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
all_member_emails: set[str] = user_emails
for group_email in group_emails:
all_member_emails.update(group_email_to_member_emails_map[group_email])
onyx_groups.append(
ExternalUserGroup(
id=drive_id,
user_emails=list(all_member_emails),
)
)
# Convert all group member definitions to onyx groups
for group_email, member_emails in group_email_to_member_emails_map.items():
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(member_emails),
)
)
return onyx_groups
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -131,23 +19,34 @@ def gdrive_group_sync(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
onyx_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Get all group emails
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
if not group_member_emails:
continue
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
)
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
)
return onyx_groups

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
final_chunk_dict[chunk.unique_id] = chunk
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,16 +79,6 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list
return chunks_to_keep

View File

@@ -58,7 +58,6 @@ def _get_objects_access_for_user_email_from_salesforce(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
logger.warning(f"User '{user_email}' not found in Salesforce")
return None
# This is the only query that is not cached in the function
@@ -66,7 +65,6 @@ def _get_objects_access_for_user_email_from_salesforce(
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
logger.debug(f"Object ID to access: {object_id_to_access}")
return object_id_to_access

View File

@@ -42,18 +42,11 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# try emails
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
@@ -168,10 +161,7 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
@@ -7,15 +5,35 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
return channel_doc_map
def _fetch_workspace_permissions(
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
@@ -95,37 +113,9 @@ def _fetch_channel_permissions(
return channel_permissions
def _get_slack_document_access(
cc_pair: ConnectorCredentialPair,
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)
if callback:
if callback.should_stop():
raise RuntimeError("_get_slack_document_access: Stop signal detected")
callback.progress("_get_slack_document_access", 1)
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -136,12 +126,9 @@ def slack_doc_sync(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
if not user_id_to_email_map:
raise ValueError(
"No user id to email map found. Please check to make sure that "
"your Slack bot token has the `users:read.email` scope"
)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,
)
@@ -151,8 +138,18 @@ def slack_doc_sync(
user_id_to_email_map=user_id_to_email_map,
)
yield from _get_slack_document_access(
cc_pair=cc_pair,
channel_permissions=channel_permissions,
callback=callback,
)
document_external_accesses = []
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for channel the channel_id
continue
for doc_id in doc_ids:
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=doc_id,
)
)
return document_external_accesses

View File

@@ -1,10 +1,7 @@
from collections.abc import Callable
from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -18,20 +15,17 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
Generator[DocExternalAccess, None, None],
list[DocExternalAccess],
]
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],
@@ -68,13 +62,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -15,7 +13,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -64,15 +62,7 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -95,26 +85,10 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# Use the configured scopes
base_scopes=oidc_scopes,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
@@ -146,7 +120,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(
@@ -170,8 +144,4 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application

View File

@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -80,7 +80,7 @@ def oneoff_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig,
slack_channel_config: SlackChannelConfig | None,
prompt: Prompt | None,
logger: OnyxLoggingAdapter,
client: WebClient,
@@ -94,10 +94,13 @@ def _handle_standard_answers(
Returns True if standard answers are found to match the user's message and therefore,
we still need to respond to the users.
"""
# if no channel config, then no standard answers are configured
if not slack_channel_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_channel_config.standard_answer_categories
slack_channel_config.standard_answer_categories if slack_channel_config else []
)
configured_standard_answers = set(
[
@@ -147,9 +150,9 @@ def _handle_standard_answers(
db_session=db_session,
description="",
user_id=None,
persona_id=(
slack_channel_config.persona.id if slack_channel_config.persona else 0
),
persona_id=slack_channel_config.persona.id
if slack_channel_config.persona
else 0,
onyxbot_flow=True,
slack_thread_id=slack_thread_id,
)
@@ -179,7 +182,7 @@ def _handle_standard_answers(
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
chat_message = create_new_chat_message(
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
@@ -188,13 +191,8 @@ def _handle_standard_answers(
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=False,
commit=True,
)
# attach the standard answers to the chat message
chat_message.standard_answers = [
standard_answer for standard_answer, _ in matching_standard_answers
]
db_session.commit()
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
@@ -216,7 +214,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread_or_channel(
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,7 +229,6 @@ def _handle_standard_answers(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True

View File

@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
admin_router = APIRouter(prefix="/admin/enterprise-settings")
@@ -131,49 +131,31 @@ def put_logo(
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
try:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
file_store = get_default_file_store(db_session)
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
file_io = file_store.read_file(filename, mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")
except Exception:
raise HTTPException(
status_code=404,
detail="No logo file found",
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail="No logotype file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
return fetch_logotype_helper(db_session)
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
@basic_router.get("/logo")
def fetch_logo(
is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
if is_logotype:
return fetch_logotype_helper(db_session)
return fetch_logo_helper(db_session)
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
@admin_router.put("/custom-analytics-script")

View File

@@ -13,7 +13,6 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -22,18 +21,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def load_settings() -> EnterpriseSettings:
"""Loads settings data directly from DB. This should be used primarily
for checking what is actually in the DB, aka for editing and saving back settings.
Runtime settings actually used by the application should be checked with
load_runtime_settings as defaults may be applied at runtime.
"""
dynamic_config_store = get_kv_store()
try:
settings = EnterpriseSettings(
@@ -47,24 +36,9 @@ def load_settings() -> EnterpriseSettings:
def store_settings(settings: EnterpriseSettings) -> None:
"""Stores settings directly to the kv store / db."""
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
def load_runtime_settings() -> EnterpriseSettings:
"""Loads settings from DB and applies any defaults or transformations for use
at runtime.
Should not be stored back to the DB.
"""
enterprise_settings = load_settings()
if not enterprise_settings.application_name:
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
return enterprise_settings
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
@@ -86,6 +60,10 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def is_valid_file_type(filename: str) -> bool:
valid_extensions = (".png", ".jpg", ".jpeg")
return filename.endswith(valid_extensions)
@@ -138,11 +116,3 @@ def upload_logo(
file_type=file_type,
)
return True
def get_logo_filename() -> str:
return _LOGO_FILENAME
def get_logotype_filename() -> str:
return _LOGOTYPE_FILENAME

View File

@@ -10,7 +10,6 @@ from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
@@ -33,7 +32,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
return await call_next(request)
except Exception as e:
logger.exception(f"Error in tenant ID middleware: {str(e)}")
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise
@@ -44,76 +43,50 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) The anonymous user cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if token_data:
tenant_id_from_payload = token_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else None
)
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if tenant_id and not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(
anonymous_user_cookie
)
tenant_id = anonymous_user_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
if not tenant_id or not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
return tenant_id
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
return tenant_id
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
if tenant_id:
return tenant_id
# As a final step, check for explicit tenant_id cookie
tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
return tenant_id_cookie
# If we've reached this point, return the default schema
return POSTGRES_DEFAULT_SCHEMA

View File

@@ -0,0 +1,629 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

Some files were not shown because too many files have changed in this diff Show More