mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
161 Commits
github_lis
...
v0.26.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
efab12a962 | ||
|
|
a5234a398b | ||
|
|
d50a17db21 | ||
|
|
dc5a1e8fd0 | ||
|
|
c0b3681650 | ||
|
|
7ec04484d4 | ||
|
|
1cf966ecc1 | ||
|
|
8a8526dbbb | ||
|
|
be20586ba1 | ||
|
|
a314462d1e | ||
|
|
155f53c3d7 | ||
|
|
7c027df186 | ||
|
|
0a5db96026 | ||
|
|
daef985b02 | ||
|
|
b7ece296e0 | ||
|
|
d7063e0a1d | ||
|
|
ee073f6d30 | ||
|
|
2e524816a0 | ||
|
|
47ef0c8658 | ||
|
|
806de92feb | ||
|
|
da39f32fea | ||
|
|
2a87837ce1 | ||
|
|
7491cdd0f0 | ||
|
|
aabd698295 | ||
|
|
4b725e4d1a | ||
|
|
34d2d92fa8 | ||
|
|
3a3b2a2f8d | ||
|
|
ccd372cc4a | ||
|
|
ea30f1de1e | ||
|
|
a7130681d9 | ||
|
|
04911db715 | ||
|
|
feae7d0cc4 | ||
|
|
ac19c64b3c | ||
|
|
03d5c30fd2 | ||
|
|
e988c13e1d | ||
|
|
dc18d53133 | ||
|
|
a1cef389aa | ||
|
|
db8d6ce538 | ||
|
|
e8370dcb24 | ||
|
|
9951fe13ba | ||
|
|
56f8ab927b | ||
|
|
cb5bbd3812 | ||
|
|
742d29e504 | ||
|
|
ecc155d082 | ||
|
|
0857e4809d | ||
|
|
22e00a1f5c | ||
|
|
0d0588a0c1 | ||
|
|
aab777f844 | ||
|
|
babbe7689a | ||
|
|
a123661c92 | ||
|
|
c554889baf | ||
|
|
f08fa878a6 | ||
|
|
d307534781 | ||
|
|
6f54791910 | ||
|
|
0d5497bb6b | ||
|
|
7648627503 | ||
|
|
927554d5ca | ||
|
|
7dcec6caf5 | ||
|
|
036648146d | ||
|
|
2aa4697ac8 | ||
|
|
bc9b4e4f45 | ||
|
|
178a64f298 | ||
|
|
c79f1edf1d | ||
|
|
7c8e23aa54 | ||
|
|
d37b427d52 | ||
|
|
a65fefd226 | ||
|
|
bb09bde519 | ||
|
|
0f6cf0fc58 | ||
|
|
fed06b592d | ||
|
|
8d92a1524e | ||
|
|
ecfea9f5ed | ||
|
|
b269f1ba06 | ||
|
|
30c878efa5 | ||
|
|
2024776c19 | ||
|
|
431316929c | ||
|
|
c5b9c6e308 | ||
|
|
73dd188b3f | ||
|
|
79b061abbc | ||
|
|
552f1ead4f | ||
|
|
17925b49e8 | ||
|
|
55fb5c3ca5 | ||
|
|
99546e4a4d | ||
|
|
c25d56f4a5 | ||
|
|
35f3f4f120 | ||
|
|
25b69a8aca | ||
|
|
1b7d710b2a | ||
|
|
ae3d3db3f4 | ||
|
|
fb79a9e700 | ||
|
|
587ba11bbc | ||
|
|
fce81ebb60 | ||
|
|
61facfb0a8 | ||
|
|
52b96854a2 | ||
|
|
d123713c00 | ||
|
|
775c847f82 | ||
|
|
6d330131fd | ||
|
|
0292ca2445 | ||
|
|
15dd1e72ca | ||
|
|
91c9be37c0 | ||
|
|
2a01c854a0 | ||
|
|
85ebadc8eb | ||
|
|
5dda53eec3 | ||
|
|
72bf427cc2 | ||
|
|
f421c6010b | ||
|
|
0b87549f35 | ||
|
|
06624a988d | ||
|
|
ae774105e3 | ||
|
|
4dafc3aa6d | ||
|
|
5d7d471823 | ||
|
|
61366df34c | ||
|
|
1a444245f6 | ||
|
|
c32d234491 | ||
|
|
07b68436cf | ||
|
|
293d1a4476 | ||
|
|
ba514aaaa2 | ||
|
|
f45798b5dd | ||
|
|
64ff5df083 | ||
|
|
cf1b7e7a93 | ||
|
|
63692a6bd3 | ||
|
|
934700b928 | ||
|
|
b1a7cff9e0 | ||
|
|
463340b8a1 | ||
|
|
ba82888e1e | ||
|
|
39465d3104 | ||
|
|
b4ecc870b9 | ||
|
|
a2ac9f02fb | ||
|
|
f87e559cc4 | ||
|
|
5883336d5e | ||
|
|
0153ff6b51 | ||
|
|
2f8f0f01be | ||
|
|
a9e5ae2f11 | ||
|
|
997f40500d | ||
|
|
a918a84e7b | ||
|
|
090f3fe817 | ||
|
|
4e70f99214 | ||
|
|
ecbd4eb1ad | ||
|
|
f94d335d12 | ||
|
|
59a388ce0a | ||
|
|
9cd3cbb978 | ||
|
|
ab1b6b487e | ||
|
|
6ead9510a4 | ||
|
|
965f9e98bf | ||
|
|
426883bbf5 | ||
|
|
6ca400ced9 | ||
|
|
104c4b9f4d | ||
|
|
8b5e8bd5b9 | ||
|
|
7f7621d7c0 | ||
|
|
06dcc28d05 | ||
|
|
18df63dfd9 | ||
|
|
0d3c72acbf | ||
|
|
9217243e3e | ||
|
|
61ccba82a9 | ||
|
|
9e8eba23c3 | ||
|
|
0c29743538 | ||
|
|
08b2421947 | ||
|
|
ed518563db | ||
|
|
a32f7dc936 | ||
|
|
798e10c52f | ||
|
|
bf4983e35a | ||
|
|
b7da91e3ae | ||
|
|
29382656fc | ||
|
|
7d6db8d500 |
@@ -12,29 +12,40 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
|
||||
# Bypassing this for now as the idea of not building is glitching
|
||||
# releases and builds that depends on everything being tagged in docker
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
# check_model_server_changes:
|
||||
# runs-on: ubuntu-latest
|
||||
# outputs:
|
||||
# changed: ${{ steps.check.outputs.changed }}
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Check if relevant files changed
|
||||
# id: check
|
||||
# run: |
|
||||
# # Default to "false"
|
||||
# echo "changed=false" >> $GITHUB_OUTPUT
|
||||
#
|
||||
# # Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# # set changed=true
|
||||
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
# echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# fi
|
||||
|
||||
check_model_server_changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
changed: "true"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if relevant files changed
|
||||
id: check
|
||||
run: |
|
||||
# Default to "false"
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
|
||||
# Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# set changed=true
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Bypass check and set output
|
||||
run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
build-amd64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
|
||||
209
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
209
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
@@ -0,0 +1,209 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
16
.github/workflows/pr-python-connector-tests.yml
vendored
16
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
@@ -8,6 +9,10 @@ on:
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# AWS
|
||||
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
@@ -44,14 +49,21 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
# Highspot
|
||||
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
|
||||
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -76,7 +88,7 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
775
.vscode/launch.template.jsonc
vendored
775
.vscode/launch.template.jsonc
vendored
@@ -6,396 +6,419 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": ["--- Individual ---"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"model_server.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"onyx.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user files indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_files_indexing@%n",
|
||||
"-Q",
|
||||
"user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user files indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +114,4 @@ To try the Onyx Enterprise Edition:
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ Edition features outside of personal development or testing purposes. Please rea
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY ./static /app/static
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY ./scripts/debugging /app/scripts/debugging
|
||||
|
||||
@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
|
||||
visit https://github.com/onyx-dot-app/onyx."
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
@@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
|
||||
@@ -84,7 +84,7 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
level = INFO
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
# hidden! (defaults to level=WARN)
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -64,7 +68,7 @@ def include_object(
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
# continue on error with individual tenant
|
||||
# only applies to online migrations
|
||||
continue_on_error = x_args.get("continue", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
(
|
||||
schema_name,
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
) = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
build_connection_string(),
|
||||
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
i_tenant = 0
|
||||
num_tenants = len(tenant_schemas)
|
||||
for schema in tenant_schemas:
|
||||
i_tenant += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
|
||||
)
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
raise
|
||||
if not continue_on_error:
|
||||
logger.error("--continue is not set, raising exception!")
|
||||
raise
|
||||
|
||||
logger.warning("--continue is set, continuing to next schema.")
|
||||
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
"""This doesn't really get used when we migrate in the cloud."""
|
||||
|
||||
logger.info("run_migrations_offline starting.")
|
||||
|
||||
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""add chunk stats table
|
||||
|
||||
Revision ID: 3781a5eb12cb
|
||||
Revises: df46c75b714e
|
||||
Create Date: 2025-03-10 10:02:30.586666
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3781a5eb12cb"
|
||||
down_revision = "df46c75b714e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chunk_stats",
|
||||
sa.Column("id", sa.String(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"document_id",
|
||||
sa.String(),
|
||||
sa.ForeignKey("document.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
|
||||
sa.Column("information_content_boost", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
index=True,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
|
||||
sa.UniqueConstraint(
|
||||
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
|
||||
op.drop_table("chunk_stats")
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Update GitHub connector repo_name to repositories
|
||||
|
||||
Revision ID: 3934b1bc7b62
|
||||
Revises: b7c2b63c4a03
|
||||
Create Date: 2025-03-05 10:50:30.516962
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import logging
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3934b1bc7b62"
|
||||
down_revision = "b7c2b63c4a03"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all GitHub connectors
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
updated_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
logger.warning(f"Connector {connector_id} has no config, skipping")
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repo_name" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repositories instead of repo_name
|
||||
new_config = dict(config)
|
||||
repo_name_value = new_config.pop("repo_name")
|
||||
new_config["repositories"] = repo_name_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating connector {connector_id}: {str(e)}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
logger.debug(
|
||||
"Starting rollback of GitHub connectors from repositories to repo_name"
|
||||
)
|
||||
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
|
||||
|
||||
# Revert each GitHub connector to use repo_name instead of repositories
|
||||
reverted_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repositories" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repo_name instead of repositories
|
||||
new_config = dict(config)
|
||||
repositories_value = new_config.pop("repositories")
|
||||
new_config["repo_name"] = repositories_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"new_config": json.dumps(new_config), "connector_id": connector_id},
|
||||
)
|
||||
reverted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error reverting connector {connector_id}: {str(e)}")
|
||||
@@ -28,6 +28,20 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, drop any existing indexes to avoid conflicts
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
|
||||
# Drop existing columns if they exist
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
|
||||
|
||||
# Create a GIN index for full-text search on chat_message.message
|
||||
op.execute(
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,10 @@ Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
@@ -15,21 +18,45 @@ down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get database connection
|
||||
"""Conflicts on lowercasing will result in the uppercased email getting a
|
||||
unique integer suffix when converted to lowercase."""
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
# Fetch all user emails that are not already lowercase
|
||||
user_emails = connection.execute(
|
||||
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
|
||||
).fetchall()
|
||||
|
||||
for user_id, email in user_emails:
|
||||
email = cast(str, email)
|
||||
username, domain = email.rsplit("@", 1)
|
||||
new_email = f"{username.lower()}@{domain.lower()}"
|
||||
attempt = 1
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try updating the email
|
||||
connection.execute(
|
||||
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
|
||||
{"new_email": new_email, "user_id": user_id},
|
||||
)
|
||||
break # Success, exit loop
|
||||
except IntegrityError:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
"""duplicated no-harm user file migration
|
||||
|
||||
Revision ID: 6a804aeb4830
|
||||
Revises: 8e1ac4f39a9f
|
||||
Create Date: 2025-04-01 07:26:10.539362
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
import datetime
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6a804aeb4830"
|
||||
down_revision = "8e1ac4f39a9f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if user_file table already exists
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
|
||||
if not inspector.has_table("user_file"):
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,50 @@
|
||||
"""enable contextual retrieval
|
||||
|
||||
Revision ID: 8e1ac4f39a9f
|
||||
Revises: 9aadf32dfeb4
|
||||
Create Date: 2024-12-20 13:29:09.918661
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8e1ac4f39a9f"
|
||||
down_revision = "9aadf32dfeb4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"enable_contextual_rag",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_provider",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "enable_contextual_rag")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_name")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_provider")
|
||||
113
backend/alembic/versions/9aadf32dfeb4_add_user_files.py
Normal file
113
backend/alembic/versions/9aadf32dfeb4_add_user_files.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""add user files
|
||||
|
||||
Revision ID: 9aadf32dfeb4
|
||||
Revises: 3781a5eb12cb
|
||||
Create Date: 2025-01-26 16:08:21.551022
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import datetime
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9aadf32dfeb4"
|
||||
down_revision = "3781a5eb12cb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the persona__user_folder table
|
||||
op.drop_table("persona__user_folder")
|
||||
# Drop the persona__user_file table
|
||||
op.drop_table("persona__user_file")
|
||||
# Drop the user_file table
|
||||
op.drop_table("user_file")
|
||||
# Drop the user_folder table
|
||||
op.drop_table("user_folder")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add_default_vision_provider_to_llm_provider
|
||||
|
||||
Revision ID: df46c75b714e
|
||||
Revises: 3934b1bc7b62
|
||||
Create Date: 2025-03-11 16:20:19.038945
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "df46c75b714e"
|
||||
down_revision = "3934b1bc7b62"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_default_vision_provider",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "default_vision_model")
|
||||
op.drop_column("llm_provider", "is_default_vision_provider")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add prompt length limit
|
||||
|
||||
Revision ID: f71470ba9274
|
||||
Revises: 6a804aeb4830
|
||||
Create Date: 2025-04-01 15:07:14.977435
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f71470ba9274"
|
||||
down_revision = "6a804aeb4830"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
@@ -0,0 +1,77 @@
|
||||
"""updated constraints for ccpairs
|
||||
|
||||
Revision ID: f7505c5b0284
|
||||
Revises: f71470ba9274
|
||||
Create Date: 2025-04-01 17:50:42.504818
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7505c5b0284"
|
||||
down_revision = "f71470ba9274"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Drop the old foreign-key constraints
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# 2) Re-add them with ondelete='CASCADE'
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
source_table="document_by_connector_credential_pair",
|
||||
referent_table="connector",
|
||||
local_cols=["connector_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
source_table="document_by_connector_credential_pair",
|
||||
referent_table="credential",
|
||||
local_cols=["credential_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse the changes for rollback
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate without CASCADE
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
"connector",
|
||||
["connector_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
"credential",
|
||||
["credential_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add new available tenant table
|
||||
|
||||
Revision ID: 3b45e0018bf1
|
||||
Revises: ac842f85f932
|
||||
Create Date: 2025-03-06 09:55:18.229910
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b45e0018bf1"
|
||||
down_revision = "ac842f85f932"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create new_available_tenant table
|
||||
op.create_table(
|
||||
"available_tenant",
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.Column("alembic_version", sa.String(), nullable=False),
|
||||
sa.Column("date_created", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop new_available_tenant table
|
||||
op.drop_table("available_tenant")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""new column user tenant mapping
|
||||
|
||||
Revision ID: ac842f85f932
|
||||
Revises: 34e3630c7f32
|
||||
Create Date: 2025-03-03 13:30:14.802874
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ac842f85f932"
|
||||
down_revision = "34e3630c7f32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add active column with default value of True
|
||||
op.add_column(
|
||||
"user_tenant_mapping",
|
||||
sa.Column(
|
||||
"active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
|
||||
|
||||
# Create a unique index for active=true records
|
||||
# This ensures a user can only be active in one tenant at a time
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the unique index for active=true records
|
||||
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_email", "user_tenant_mapping", ["email"], schema="public"
|
||||
)
|
||||
|
||||
# Remove the active column
|
||||
op.drop_column("user_tenant_mapping", "active", schema="public")
|
||||
@@ -93,12 +93,12 @@ def _get_access_for_documents(
|
||||
)
|
||||
|
||||
# To avoid collisions of group namings between connectors, they need to be prefixed
|
||||
access_map[document_id] = DocumentAccess(
|
||||
user_emails=non_ee_access.user_emails,
|
||||
user_groups=set(user_group_info.get(document_id, [])),
|
||||
access_map[document_id] = DocumentAccess.build(
|
||||
user_emails=list(non_ee_access.user_emails),
|
||||
user_groups=user_group_info.get(document_id, []),
|
||||
is_public=is_public_anywhere,
|
||||
external_user_emails=ext_u_emails,
|
||||
external_user_group_ids=ext_u_groups,
|
||||
external_user_emails=list(ext_u_emails),
|
||||
external_user_group_ids=list(ext_u_groups),
|
||||
)
|
||||
return access_map
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
elif isinstance(packet, OnyxContexts):
|
||||
response.contexts = packet
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
|
||||
@@ -27,6 +27,8 @@ def get_empty_chat_messages_entries__paginated(
|
||||
first element is the most recent timestamp out of the sessions iterated
|
||||
- this timestamp can be used to paginate forward in time
|
||||
second element is a list of messages belonging to all the sessions iterated
|
||||
|
||||
Only messages of type USER are returned
|
||||
"""
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=period[0],
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Rules defined here:
|
||||
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
|
||||
"""
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
@@ -263,13 +264,11 @@ def _fetch_all_page_restrictions(
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
Otherwise, use the space's restrictions.
|
||||
"""
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
@@ -286,11 +285,9 @@ def _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
# If there are restrictions, then we don't need to use the space's restrictions
|
||||
continue
|
||||
@@ -324,11 +321,9 @@ def _fetch_all_page_restrictions(
|
||||
continue
|
||||
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
@@ -342,13 +337,12 @@ def _fetch_all_page_restrictions(
|
||||
)
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -387,7 +381,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions(
|
||||
yield from _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
@@ -34,7 +35,7 @@ def _get_slim_doc_generator(
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -48,7 +49,6 @@ def gmail_doc_sync(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
)
|
||||
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -60,17 +60,14 @@ def gmail_doc_sync(
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
if user_email := slim_doc.perm_sync_data.get("user_email"):
|
||||
ext_access = ExternalAccess(
|
||||
external_user_emails=set([user_email]),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
document_external_access.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
|
||||
return document_external_access
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc(
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -161,7 +162,6 @@ def gdrive_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -174,10 +174,7 @@ def gdrive_doc_sync(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -12,6 +14,7 @@ logger = setup_logger()
|
||||
|
||||
def _get_drive_members(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
admin_service: AdminService,
|
||||
) -> dict[str, tuple[set[str], set[str]]]:
|
||||
"""
|
||||
This builds a map of drive ids to their members (group and user emails).
|
||||
@@ -20,6 +23,8 @@ def _get_drive_members(
|
||||
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
|
||||
}
|
||||
"""
|
||||
|
||||
# fetches shared drives only
|
||||
drive_ids = google_drive_connector.get_all_drive_ids()
|
||||
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
|
||||
@@ -28,20 +33,44 @@ def _get_drive_members(
|
||||
google_drive_connector.primary_admin_email,
|
||||
)
|
||||
|
||||
admin_user_info = (
|
||||
admin_service.users()
|
||||
.get(userKey=google_drive_connector.primary_admin_email)
|
||||
.execute()
|
||||
)
|
||||
is_admin = admin_user_info.get("isAdmin", False) or admin_user_info.get(
|
||||
"isDelegatedAdmin", False
|
||||
)
|
||||
|
||||
for drive_id in drive_ids:
|
||||
group_emails: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
|
||||
try:
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
# can only set `useDomainAdminAccess` to true if the user
|
||||
# is an admin
|
||||
useDomainAdminAccess=is_admin,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
except HttpError as e:
|
||||
if e.status_code == 404:
|
||||
logger.warning(
|
||||
f"Error getting permissions for drive id {drive_id}. "
|
||||
f"User '{google_drive_connector.primary_admin_email}' likely "
|
||||
f"does not have access to this drive. Exception: {e}"
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
|
||||
return drive_id_to_members_map
|
||||
|
||||
@@ -132,7 +161,7 @@ def gdrive_group_sync(
|
||||
)
|
||||
|
||||
# Get all drive members
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector)
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
|
||||
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_groups(
|
||||
|
||||
@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
chunks_to_keep = []
|
||||
final_chunk_dict: dict[str, InferenceChunk] = {}
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
|
||||
if chunk.source_type in sources_to_censor:
|
||||
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
|
||||
else:
|
||||
chunks_to_keep.append(chunk)
|
||||
final_chunk_dict[chunk.unique_id] = chunk
|
||||
|
||||
# For each source, filter out the chunks using the permission
|
||||
# check function for that source
|
||||
@@ -79,6 +79,16 @@ def _post_query_chunk_censoring(
|
||||
f" chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
chunks_to_keep.extend(censored_chunks)
|
||||
|
||||
return chunks_to_keep
|
||||
for censored_chunk in censored_chunks:
|
||||
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
|
||||
|
||||
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
|
||||
final_chunk_list: list[InferenceChunk] = []
|
||||
for chunk in chunks:
|
||||
# only if the chunk is in the final censored chunks, add it to the final list
|
||||
# if it is missing, that means it was intentionally left out
|
||||
if chunk.unique_id in final_chunk_dict:
|
||||
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
|
||||
|
||||
return final_chunk_list
|
||||
|
||||
@@ -58,6 +58,7 @@ def _get_objects_access_for_user_email_from_salesforce(
|
||||
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
|
||||
)
|
||||
if user_id is None:
|
||||
logger.warning(f"User '{user_email}' not found in Salesforce")
|
||||
return None
|
||||
|
||||
# This is the only query that is not cached in the function
|
||||
@@ -65,6 +66,7 @@ def _get_objects_access_for_user_email_from_salesforce(
|
||||
object_id_to_access = get_objects_access_for_user_id(
|
||||
salesforce_client, user_id, list(object_ids)
|
||||
)
|
||||
logger.debug(f"Object ID to access: {object_id_to_access}")
|
||||
return object_id_to_access
|
||||
|
||||
|
||||
|
||||
@@ -42,11 +42,18 @@ def get_any_salesforce_client_for_doc_id(
|
||||
|
||||
|
||||
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
|
||||
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) == 0:
|
||||
return None
|
||||
return result["records"][0]["Id"]
|
||||
if len(result["records"]) > 0:
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
# try emails
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) > 0:
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# This contains only the user_ids that we have found in Salesforce.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
@@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
if channel_id not in channel_doc_map:
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_metadata.id)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
def _fetch_workspace_permissions(
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
@@ -122,10 +95,37 @@ def _fetch_channel_permissions(
|
||||
return channel_permissions
|
||||
|
||||
|
||||
def _get_slack_document_access(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
yield DocExternalAccess(
|
||||
external_access=channel_permissions[channel_id],
|
||||
doc_id=doc_metadata.id,
|
||||
)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("_get_slack_document_access: Stop signal detected")
|
||||
|
||||
callback.progress("_get_slack_document_access", 1)
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -136,9 +136,12 @@ def slack_doc_sync(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
)
|
||||
if not user_id_to_email_map:
|
||||
raise ValueError(
|
||||
"No user id to email map found. Please check to make sure that "
|
||||
"your Slack bot token has the `users:read.email` scope"
|
||||
)
|
||||
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
@@ -148,18 +151,8 @@ def slack_doc_sync(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
|
||||
document_external_accesses = []
|
||||
for channel_id, ext_access in channel_permissions.items():
|
||||
doc_ids = channel_doc_map.get(channel_id)
|
||||
if not doc_ids:
|
||||
# No documents found for channel the channel_id
|
||||
continue
|
||||
|
||||
for doc_id in doc_ids:
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
return document_external_accesses
|
||||
yield from _get_slack_document_access(
|
||||
cc_pair=cc_pair,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
@@ -23,7 +26,7 @@ DocSyncFuncType = Callable[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
@@ -65,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
|
||||
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
}
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 30 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
# Ensure we request offline_access for refresh tokens
|
||||
try:
|
||||
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
|
||||
if "offline_access" not in oidc_scopes:
|
||||
oidc_scopes.append("offline_access")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error configuring OIDC scopes: {e}")
|
||||
# Fall back to default scopes if there's an error
|
||||
oidc_scopes = BASE_SCOPES
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
# Use the configured scopes
|
||||
base_scopes=oidc_scopes,
|
||||
),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
|
||||
@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import get_logo_filename
|
||||
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
|
||||
from ee.onyx.server.enterprise_settings.store import load_analytics_script
|
||||
from ee.onyx.server.enterprise_settings.store import load_settings
|
||||
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
||||
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import PostgresBackedFileStore
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
||||
@@ -131,31 +131,49 @@ def put_logo(
|
||||
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
|
||||
|
||||
|
||||
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
|
||||
def fetch_logo_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
|
||||
file_io = file_store.read_file(filename, mode="b")
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
|
||||
detail="No logo file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No logotype file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
@basic_router.get("/logotype")
|
||||
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
|
||||
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
|
||||
return fetch_logotype_helper(db_session)
|
||||
|
||||
|
||||
@basic_router.get("/logo")
|
||||
def fetch_logo(
|
||||
is_logotype: bool = False, db_session: Session = Depends(get_session)
|
||||
) -> Response:
|
||||
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
|
||||
if is_logotype:
|
||||
return fetch_logotype_helper(db_session)
|
||||
|
||||
return fetch_logo_helper(db_session)
|
||||
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
|
||||
@@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
|
||||
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def load_settings() -> EnterpriseSettings:
|
||||
"""Loads settings data directly from DB. This should be used primarily
|
||||
for checking what is actually in the DB, aka for editing and saving back settings.
|
||||
|
||||
Runtime settings actually used by the application should be checked with
|
||||
load_runtime_settings as defaults may be applied at runtime.
|
||||
"""
|
||||
|
||||
dynamic_config_store = get_kv_store()
|
||||
try:
|
||||
settings = EnterpriseSettings(
|
||||
@@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings:
|
||||
|
||||
|
||||
def store_settings(settings: EnterpriseSettings) -> None:
|
||||
"""Stores settings directly to the kv store / db."""
|
||||
|
||||
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
|
||||
def load_runtime_settings() -> EnterpriseSettings:
|
||||
"""Loads settings from DB and applies any defaults or transformations for use
|
||||
at runtime.
|
||||
|
||||
Should not be stored back to the DB.
|
||||
"""
|
||||
enterprise_settings = load_settings()
|
||||
if not enterprise_settings.application_name:
|
||||
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
return enterprise_settings
|
||||
|
||||
|
||||
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
|
||||
|
||||
|
||||
@@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
|
||||
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
|
||||
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def is_valid_file_type(filename: str) -> bool:
|
||||
valid_extensions = (".png", ".jpg", ".jpeg")
|
||||
return filename.endswith(valid_extensions)
|
||||
@@ -116,3 +138,11 @@ def upload_logo(
|
||||
file_type=file_type,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_logo_filename() -> str:
|
||||
return _LOGO_FILENAME
|
||||
|
||||
|
||||
def get_logotype_filename() -> str:
|
||||
return _LOGOTYPE_FILENAME
|
||||
|
||||
@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
3) Reset token cookie
|
||||
3) The anonymous user cookie
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
@@ -52,41 +52,55 @@ async def _get_tenant_id_from_request(
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
if token_data:
|
||||
tenant_id_from_payload = token_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
if tenant_id and not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(
|
||||
anonymous_user_cookie
|
||||
)
|
||||
tenant_id = anonymous_user_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
if not tenant_id or not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
|
||||
@@ -1,26 +1,35 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
|
||||
from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
@@ -46,25 +55,6 @@ logger = setup_logger()
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
def _translate_doc_response_to_simple_doc(
|
||||
doc_response: QADocsResponse,
|
||||
) -> list[SimpleDoc]:
|
||||
return [
|
||||
SimpleDoc(
|
||||
id=doc.document_id,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
link=doc.link,
|
||||
blurb=doc.blurb,
|
||||
match_highlights=[
|
||||
highlight for highlight in doc.match_highlights if highlight
|
||||
],
|
||||
source_type=doc.source_type,
|
||||
metadata=doc.metadata,
|
||||
)
|
||||
for doc in doc_response.top_documents
|
||||
]
|
||||
|
||||
|
||||
def _get_final_context_doc_indices(
|
||||
final_context_docs: list[LlmDoc] | None,
|
||||
top_docs: list[SavedSearchDoc] | None,
|
||||
@@ -89,14 +79,26 @@ def _convert_packet_stream_to_response(
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.top_documents = packet.top_documents
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -113,11 +115,104 @@ def _convert_packet_stream_to_response(
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -6,9 +8,9 @@ from pydantic import model_validator
|
||||
|
||||
from ee.onyx.server.manage.models import StandardAnswer
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -88,6 +90,64 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
@@ -103,10 +163,14 @@ class ChatBasicResponse(BaseModel):
|
||||
cited_documents: dict[int, str] | None = None
|
||||
|
||||
# FOR BACKWARDS COMPATIBILITY
|
||||
# TODO: deprecate both of these
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
@@ -153,4 +217,3 @@ class OneShotQAResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
contexts: OnyxContexts | None = None
|
||||
|
||||
@@ -48,10 +48,15 @@ def fetch_and_process_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None,
|
||||
limit: int | None = 500,
|
||||
) -> list[ChatSessionSnapshot]:
|
||||
# observed to be slow a scale of 8192 sessions and 4 messages per session
|
||||
|
||||
# this is a little slow (5 seconds)
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=start, end=end, db_session=db_session, limit=limit
|
||||
)
|
||||
|
||||
# this is VERY slow (80 seconds) due to create_chat_chain being called
|
||||
# for each session. Needs optimizing.
|
||||
chat_session_snapshots = [
|
||||
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
|
||||
for chat_session in chat_sessions
|
||||
@@ -246,6 +251,8 @@ def get_query_history_as_csv(
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
# this call is very expensive and is timing out via endpoint
|
||||
# TODO: optimize call and/or generate via background task
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
|
||||
@@ -38,6 +38,7 @@ router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
logger.debug(f"Attempting to upsert SAML user with email: {email}")
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
@@ -48,9 +49,13 @@ async def upsert_saml_user(email: str) -> User:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
try:
|
||||
return await user_manager.get_by_email(email)
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.role.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
logger.notice("Creating user from SAML login")
|
||||
logger.info("Creating user from SAML login")
|
||||
|
||||
user_count = await get_user_count()
|
||||
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
|
||||
@@ -59,11 +64,10 @@ async def upsert_saml_user(email: str) -> User:
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
|
||||
user: User = await user_manager.create(
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=hashed_pass,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
45
backend/ee/onyx/server/tenants/admin_api.py
Normal file
45
backend/ee/onyx/server/tenants/admin_api.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
98
backend/ee/onyx/server/tenants/anonymous_users_api.py
Normal file
98
backend/ee/onyx/server/tenants/anonymous_users_api.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
@@ -1,269 +1,24 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.user_invitations_api import (
|
||||
router as user_invitations_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
# Create a main router to include all sub-routers
|
||||
# Note: We don't add a prefix here as each router already has the /tenants prefix
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
# Include all the individual routers
|
||||
router.include_router(admin_router)
|
||||
router.include_router(anonymous_users_router)
|
||||
router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
|
||||
96
backend/ee/onyx/server/tenants/billing_api.py
Normal file
96
backend/ee/onyx/server/tenants/billing_api.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
creator_email: str
|
||||
|
||||
|
||||
class TenantByDomainRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class RequestInviteRequest(BaseModel):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class RequestInviteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PendingUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
@@ -26,11 +28,12 @@ from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
@@ -60,42 +63,72 @@ async def get_or_provision_tenant(
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
# Early return for non-multi-tenant mode
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
# First, check if the user already has a tenant
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
return tenant_id
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
# User doesn't exist, so we need to create a new tenant or assign an existing one
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try to get a pre-provisioned tenant
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
if not tenant_id:
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
error_msg = "Failed to provision tenant"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
status_code=500,
|
||||
detail="Failed to provision tenant. Please try again later.",
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
This is the fallback method when we can't use a pre-provisioned tenant.
|
||||
|
||||
"""
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
logger.info(f"Creating new tenant {tenant_id} for user {email}")
|
||||
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
try:
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
@@ -109,54 +142,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
)
|
||||
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Create the schema for the tenant
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
# Set up the tenant with all necessary configurations
|
||||
await setup_tenant(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
# Assign the tenant to the user
|
||||
await assign_tenant_to_user(tenant_id, email)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
@@ -187,20 +191,74 @@ async def notify_control_plane(
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
"""
|
||||
Logic to rollback tenant provisioning on data plane.
|
||||
Handles each step independently to ensure maximum cleanup even if some steps fail.
|
||||
"""
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
# Track if any part of the rollback fails
|
||||
rollback_errors = []
|
||||
|
||||
# 1. Try to drop the tenant's schema
|
||||
try:
|
||||
drop_schema(tenant_id)
|
||||
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 2. Try to remove tenant mapping
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Successfully removed user mappings for tenant {tenant_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 3. If this tenant was in the available tenants table, remove it
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter(AvailableTenant.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Removed tenant {tenant_id} from available tenants table"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# Log summary of rollback operation
|
||||
if rollback_errors:
|
||||
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
|
||||
else:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
@@ -213,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@@ -225,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
@@ -233,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
@@ -353,3 +413,154 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_by_domain_from_control_plane(
|
||||
domain: str,
|
||||
tenant_id: str,
|
||||
) -> TenantByDomainResponse | None:
|
||||
"""
|
||||
Fetches tenant information from the control plane based on the email domain.
|
||||
|
||||
Args:
|
||||
domain: The email domain to search for (e.g., "example.com")
|
||||
|
||||
Returns:
|
||||
A dictionary containing tenant information if found, None otherwise
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
|
||||
headers=headers,
|
||||
json={"domain": domain, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Control plane tenant lookup failed: {response.text}")
|
||||
return None
|
||||
|
||||
response_data = response.json()
|
||||
if not response_data:
|
||||
return None
|
||||
|
||||
return TenantByDomainResponse(
|
||||
tenant_id=response_data.get("tenant_id"),
|
||||
number_of_users=response_data.get("number_of_users"),
|
||||
creator_email=response_data.get("creator_email"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tenant by domain: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_available_tenant() -> str | None:
|
||||
"""
|
||||
Get an available pre-provisioned tenant from the NewAvailableTenant table.
|
||||
Returns the tenant_id if one is available, None otherwise.
|
||||
Uses row-level locking to prevent race conditions when multiple processes
|
||||
try to get an available tenant simultaneously.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
db_session.begin()
|
||||
|
||||
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.order_by(AvailableTenant.date_created)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
tenant_id = available_tenant.tenant_id
|
||||
# Remove the tenant from the available tenants table
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(f"Using pre-provisioned tenant {tenant_id}")
|
||||
return tenant_id
|
||||
else:
|
||||
db_session.rollback()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error getting available tenant")
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
|
||||
async def setup_tenant(tenant_id: str) -> None:
|
||||
"""
|
||||
Set up a tenant with all necessary configurations.
|
||||
This is a centralized function that handles all tenant setup logic.
|
||||
"""
|
||||
token = None
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations in a way that isolates it from the current event loop
|
||||
# Create a new event loop for this synchronous operation
|
||||
loop = asyncio.get_event_loop()
|
||||
# Use run_in_executor which properly isolates the thread execution
|
||||
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Configure default API keys
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
# Set up Onyx with appropriate settings
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to set up tenant {tenant_id}")
|
||||
raise e
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
Uses transaction handling to ensure atomicity and includes retry logic
|
||||
for control plane notifications.
|
||||
"""
|
||||
# First, add the user to the tenant in a transaction
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
@@ -74,3 +74,21 @@ def drop_schema(tenant_id: str) -> None:
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
|
||||
|
||||
def get_current_alembic_version(tenant_id: str) -> str:
|
||||
"""Get the current Alembic version for a tenant."""
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import text
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Set the search path to the tenant's schema
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Get the current version from the alembic_version table
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
return current_rev or "head"
|
||||
|
||||
67
backend/ee/onyx/server/tenants/team_membership_api.py
Normal file
67
backend/ee/onyx/server/tenants/team_membership_api.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
39
backend/ee/onyx/server/tenants/tenant_management_api.py
Normal file
39
backend/ee/onyx/server/tenants/tenant_management_api.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
"gmail",
|
||||
"outlook",
|
||||
"yahoo",
|
||||
"hotmail",
|
||||
"icloud",
|
||||
"msn",
|
||||
"hotmail",
|
||||
"hotmail.co.uk",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
if not user:
|
||||
return None
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
return None
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return get_tenant_by_domain_from_control_plane(domain, tenant_id)
|
||||
90
backend/ee/onyx/server/tenants/user_invitations_api.py
Normal file
90
backend/ee/onyx/server/tenants/user_invitations_api.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.models import ApproveUserRequest
|
||||
from ee.onyx.server.tenants.models import PendingUserSnapshot
|
||||
from ee.onyx.server.tenants.models import RequestInviteRequest
|
||||
from ee.onyx.server.tenants.user_mapping import accept_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
|
||||
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
|
||||
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
accept_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to accept invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to accept invitation")
|
||||
|
||||
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deny_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to deny invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to deny invitation")
|
||||
@@ -1,27 +1,56 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.invited_users import write_pending_users
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# First try to get an active tenant
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
tenant_id = mapping.tenant_id if mapping else None
|
||||
|
||||
# If no active tenant found, try to get the first inactive one
|
||||
if tenant_id is None:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
if mapping:
|
||||
# Mark this mapping as active
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
@@ -38,13 +67,56 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.commit()
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
@@ -76,3 +148,187 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
return
|
||||
write_pending_users(pending_users + [email])
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Approve a user invite to a tenant.
|
||||
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete all existing records for this email
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email
|
||||
).delete()
|
||||
|
||||
# Create a new mapping entry for the user in this tenant
|
||||
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_pending_users(pending_users)
|
||||
|
||||
# Add to invited users
|
||||
invited_users = get_invited_users()
|
||||
if email not in invited_users:
|
||||
invited_users.append(email)
|
||||
write_invited_users(invited_users)
|
||||
|
||||
|
||||
def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
logger.info(
|
||||
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Find the inactive mapping for this user and tenant
|
||||
mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# Set all other mappings for this user to inactive
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
).update({"active": False})
|
||||
|
||||
# Activate this mapping
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
This removes the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete the mapping for this user and tenant
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
if result:
|
||||
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_invited_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_invited_users(pending_users)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return user_count
|
||||
|
||||
|
||||
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
|
||||
"""
|
||||
Get the first tenant invitation for this user
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Get the first tenant invitation for this user
|
||||
invitation = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if invitation:
|
||||
# Get the user count for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == invitation.tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return TenantSnapshot(
|
||||
tenant_id=invitation.tenant_id, number_of_users=user_count
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
BIN
backend/hello-vmlinux.bin
Normal file
BIN
backend/hello-vmlinux.bin
Normal file
Binary file not shown.
@@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.onyx_torch_model import ConnectorClassifier
|
||||
from model_server.onyx_torch_model import HybridClassifier
|
||||
@@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
@@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
_INTENT_TOKENIZER: AutoTokenizer | None = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> AutoTokenizer:
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
@@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
|
||||
|
||||
def get_local_intent_model(
|
||||
model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
tag: str = INTENT_MODEL_TAG,
|
||||
tag: str | None = INTENT_MODEL_TAG,
|
||||
) -> HybridClassifier:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
@@ -102,7 +120,9 @@ def get_local_intent_model(
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -112,6 +132,44 @@ def get_local_intent_model(
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load content information model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
@@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
|
||||
)
|
||||
|
||||
|
||||
def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
@@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_content_classification_inference(
|
||||
text_inputs: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
"""
|
||||
Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
boost factor.
|
||||
"""
|
||||
|
||||
def _prob_to_score(prob: float) -> float:
|
||||
"""
|
||||
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
"""
|
||||
_MIN_BASE_SCORE = 0.25
|
||||
_MAX_BASE_SCORE = 0.75
|
||||
if prob < _MIN_BASE_SCORE:
|
||||
raw_score = 0.0
|
||||
elif prob < _MAX_BASE_SCORE:
|
||||
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
else:
|
||||
raw_score = 1.0
|
||||
return (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
+ (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
* raw_score
|
||||
)
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
content_model = get_local_information_content_model()
|
||||
|
||||
# Process inputs in batches
|
||||
all_output_classes: list[int] = []
|
||||
all_base_output_probabilities: list[float] = []
|
||||
|
||||
for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
batch_with_prefix = []
|
||||
batch_indices = []
|
||||
|
||||
# Pre-allocate results for this batch
|
||||
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# Pre-process batch to handle long input exceptions
|
||||
for j, text in enumerate(batch):
|
||||
if len(text) == 0:
|
||||
# if no input, treat as non-informative from the model's perspective
|
||||
batch_output_classes[j] = np.array(0)
|
||||
batch_probabilities[j] = np.array(0.0)
|
||||
logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
elif (
|
||||
len(text.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
# if input is short, use the model
|
||||
batch_with_prefix.append(
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
)
|
||||
batch_indices.append(j)
|
||||
else:
|
||||
# if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
logger.warning("Input for Content Information Model too long")
|
||||
|
||||
if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# Get predictions for the batch
|
||||
model_output_classes = content_model(batch_with_prefix)
|
||||
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# Place results in the correct positions
|
||||
for idx, batch_idx in enumerate(batch_indices):
|
||||
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
1
|
||||
].numpy() # x[1] is prob of the positive class
|
||||
|
||||
all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
logits = [
|
||||
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
for p in all_base_output_probabilities
|
||||
]
|
||||
scaled_logits = [
|
||||
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
for logit in logits
|
||||
]
|
||||
output_probabilities_with_temp = [
|
||||
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
for scaled_logit in scaled_logits
|
||||
]
|
||||
|
||||
prediction_scores = [
|
||||
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
]
|
||||
|
||||
content_classification_predictions = [
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
@@ -362,3 +538,10 @@ async def process_analysis_request(
|
||||
|
||||
is_keyword, keywords = run_analysis(intent_request)
|
||||
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/content-classification")
|
||||
async def process_content_classification_request(
|
||||
content_classification_requests: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
return run_content_classification_inference(content_classification_requests)
|
||||
|
||||
@@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
"""Check if an exception is related to authentication issues.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
bool: True if the error appears to be authentication-related
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
_AUTH_ERROR_401 in error_str
|
||||
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||
or _AUTH_ERROR_PERMISSION in error_str
|
||||
)
|
||||
|
||||
|
||||
def format_embedding_error(
|
||||
error: Exception,
|
||||
service_name: str,
|
||||
model: str | None,
|
||||
provider: EmbeddingProvider,
|
||||
status_code: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a standardized error string for embedding errors.
|
||||
"""
|
||||
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||
|
||||
return (
|
||||
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {provider} "
|
||||
f"Exception: {error}"
|
||||
)
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails with a provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||
self.provider = provider
|
||||
self.message = message
|
||||
super().__init__(f"{provider} authentication failed: {message}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
@@ -92,31 +146,17 @@ class CloudEmbedding:
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
except Exception as e:
|
||||
error_string = (
|
||||
f"Exception embedding text with OpenAI - {type(e)}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {self.provider} "
|
||||
f"Exception: {e}"
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
logger.error(error_string)
|
||||
|
||||
# only log text when it's not an authentication error.
|
||||
if not isinstance(e, openai.AuthenticationError):
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
@@ -155,7 +195,6 @@ class CloudEmbedding:
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
@@ -239,22 +278,51 @@ class CloudEmbedding:
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e, str(self.provider), model_name or deployment_name, self.provider
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -569,6 +637,13 @@ async def process_embed_request(
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
|
||||
@@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
@@ -64,19 +65,31 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
app.state.gpu_type = gpu_type
|
||||
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
try:
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
|
||||
"This is not a critical error and the model server will continue to run."
|
||||
)
|
||||
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice("This model server should only run document indexing.")
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_access_for_document(
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
return DocumentAccess.build(
|
||||
doc_access = DocumentAccess.build(
|
||||
user_emails=info[1] if info and info[1] else [],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
@@ -26,6 +26,8 @@ def _get_access_for_document(
|
||||
is_public=info[2] if info else False,
|
||||
)
|
||||
|
||||
return doc_access
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
@@ -38,12 +40,12 @@ def get_access_for_document(
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
return DocumentAccess(
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
return DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=False,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -55,19 +57,18 @@ def _get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
doc_access = {
|
||||
document_id: DocumentAccess(
|
||||
user_emails=set([email for email in user_emails if email]),
|
||||
doc_access = {}
|
||||
for document_id, user_emails, is_public in document_access_info:
|
||||
doc_access[document_id] = DocumentAccess.build(
|
||||
user_emails=[email for email in user_emails if email],
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=set(),
|
||||
user_groups=[],
|
||||
is_public=is_public,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not be indexed by the indexing job yet, in those cases
|
||||
# Sometimes the document has not been indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
# checks the MIT version permissions and creates a superset. This ensures that this flow
|
||||
# does not fail even if the Document has not yet been indexed.
|
||||
|
||||
@@ -20,7 +20,7 @@ class ExternalAccess:
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
together. It's used for syncing document permissions to Vespa.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
@@ -56,34 +56,46 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
user_emails: set[str | None]
|
||||
|
||||
# Names of user groups associated with this document
|
||||
user_groups: set[str]
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
return set(
|
||||
[
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.user_emails
|
||||
if user_email
|
||||
]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ [
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.external_user_emails
|
||||
]
|
||||
+ [
|
||||
# The group names are already prefixed by the source type
|
||||
# This adds an additional prefix of "external_group:"
|
||||
prefix_external_group(group_name)
|
||||
for group_name in self.external_user_group_ids
|
||||
]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
external_user_emails: set[str]
|
||||
external_user_group_ids: set[str]
|
||||
is_public: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise TypeError(
|
||||
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
acl_set.add(prefix_user_email(user_email))
|
||||
|
||||
for group_name in self.user_groups:
|
||||
acl_set.add(prefix_user_group(group_name))
|
||||
|
||||
for external_user_email in self.external_user_emails:
|
||||
acl_set.add(prefix_user_email(external_user_email))
|
||||
|
||||
for external_group_id in self.external_user_group_ids:
|
||||
acl_set.add(prefix_external_group(external_group_id))
|
||||
|
||||
if self.is_public:
|
||||
acl_set.add(PUBLIC_DOC_PAT)
|
||||
|
||||
return acl_set
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
|
||||
external_user_group_ids: list[str],
|
||||
is_public: bool,
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
external_user_emails={
|
||||
prefix_user_email(external_email)
|
||||
for external_email in external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
prefix_external_group(external_group_id)
|
||||
for external_group_id in external_user_group_ids
|
||||
},
|
||||
user_emails={
|
||||
prefix_user_email(user_email)
|
||||
for user_email in user_emails
|
||||
if user_email
|
||||
},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
|
||||
|
||||
obj = object.__new__(cls)
|
||||
object.__setattr__(
|
||||
obj, "user_emails", {user_email for user_email in user_emails if user_email}
|
||||
)
|
||||
object.__setattr__(obj, "user_groups", set(user_groups))
|
||||
object.__setattr__(
|
||||
obj,
|
||||
"external_user_emails",
|
||||
{external_email for external_email in external_user_emails},
|
||||
)
|
||||
object.__setattr__(
|
||||
obj,
|
||||
"external_user_group_ids",
|
||||
{external_group_id for external_group_id in external_user_group_ids},
|
||||
)
|
||||
object.__setattr__(obj, "is_public", is_public)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
default_public_access = DocumentAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
default_public_access = DocumentAccess.build(
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContext
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
@@ -24,7 +23,7 @@ def process_llm_stream(
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -92,6 +93,7 @@ def check_sub_answer(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -119,6 +120,7 @@ def generate_sub_answer(
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
@@ -153,8 +155,8 @@ def generate_initial_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -278,6 +280,9 @@ def generate_initial_answer(
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
|
||||
@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
@@ -141,6 +142,7 @@ def decompose_orig_question(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -112,6 +113,7 @@ def compare_answers(
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
|
||||
@@ -50,13 +50,7 @@ def decide_refinement_need(
|
||||
)
|
||||
]
|
||||
|
||||
if graph_config.behavior.allow_refinement:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
else:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
@@ -96,6 +97,7 @@ def extract_entities_terms(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
@@ -179,8 +182,8 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -302,7 +305,11 @@ def generate_validate_refined_answer(
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -409,6 +416,7 @@ def generate_validate_refined_answer(
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
|
||||
if result.query_info is not None:
|
||||
query_info = result.query_info
|
||||
break
|
||||
return query_info or SearchQueryInfo(
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=1.0,
|
||||
)
|
||||
|
||||
assert query_info is not None, "must have query info"
|
||||
return query_info
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
@@ -96,6 +97,7 @@ def expand_queries(
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
|
||||
@@ -56,8 +56,8 @@ def format_results(
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents,
|
||||
final_context_sections=reranked_documents,
|
||||
get_retrieved_sections=lambda: reranked_documents,
|
||||
get_final_context_sections=lambda: reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -91,7 +91,7 @@ def retrieve_documents(
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
pre_rerank_docs = callback_container[0]
|
||||
pre_rerank_docs = callback_container[0] if callback_container else []
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -93,6 +94,7 @@ def verify_documents(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
@@ -44,7 +44,9 @@ def call_tool(
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
emit_packet(tool_kickoff, writer)
|
||||
|
||||
@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -25,6 +34,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
|
||||
def choose_tool(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
@@ -37,6 +47,31 @@ def choose_tool(
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
embedding_thread: TimeoutThread[Embedding] | None = None
|
||||
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
|
||||
override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
if (
|
||||
not agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and (
|
||||
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
|
||||
)
|
||||
):
|
||||
override_kwargs = SearchToolOverrideKwargs()
|
||||
# Run in a background thread to avoid blocking the main thread
|
||||
embedding_thread = run_in_background(
|
||||
get_query_embedding,
|
||||
agent_config.inputs.search_request.query,
|
||||
agent_config.persistence.db_session,
|
||||
)
|
||||
keyword_thread = run_in_background(
|
||||
query_analysis,
|
||||
agent_config.inputs.search_request.query,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
@@ -47,7 +82,6 @@ def choose_tool(
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
@@ -71,11 +105,22 @@ def choose_tool(
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
if embedding_thread and tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -153,10 +198,22 @@ def choose_tool(
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
if embedding_thread and selected_tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and selected_tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,18 +9,21 @@ from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
@@ -50,11 +53,11 @@ def basic_use_tool_response(
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = cast(OnyxContexts, yield_item.response).contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, yield_item.response)
|
||||
for section in search_response_summary.top_sections:
|
||||
if section.center_chunk.document_id not in initial_search_results:
|
||||
initial_search_results.append(section_to_llm_doc(section))
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
|
||||
tool: Tool
|
||||
tool_args: dict
|
||||
id: str | None
|
||||
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
EMBEDDING_KEY = "embedding"
|
||||
IS_KEYWORD_KEY = "is_keyword"
|
||||
KEYWORDS_KEY = "keywords"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
@@ -319,8 +321,10 @@ def dispatch_separated(
|
||||
sep: str = DISPATCH_SEP_CHAR,
|
||||
) -> list[BaseMessage_Content]:
|
||||
num = 1
|
||||
accumulated_tokens = ""
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
for token in tokens:
|
||||
accumulated_tokens += cast(str, token.content)
|
||||
content = cast(str, token.content)
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
@@ -402,6 +406,7 @@ def summarize_history(
|
||||
llm.invoke,
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
@@ -505,3 +510,9 @@ def get_deduplicated_structured_subquestion_documents(
|
||||
cited_documents=dedup_inference_section_list(cited_docs),
|
||||
context_documents=dedup_inference_section_list(context_docs),
|
||||
)
|
||||
|
||||
|
||||
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
|
||||
return not (
|
||||
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.image import MIMEImage
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
@@ -13,10 +14,16 @@ from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
@@ -56,6 +63,11 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.header img {{
|
||||
max-width: 140px;
|
||||
width: 140px;
|
||||
height: auto;
|
||||
filter: brightness(1.1) contrast(1.2);
|
||||
border-radius: 8px;
|
||||
padding: 5px;
|
||||
}}
|
||||
.body-content {{
|
||||
padding: 20px 30px;
|
||||
@@ -72,12 +84,16 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.cta-button {{
|
||||
display: inline-block;
|
||||
padding: 12px 20px;
|
||||
background-color: #000000;
|
||||
padding: 14px 24px;
|
||||
background-color: #0055FF;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
margin-top: 10px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
text-align: center;
|
||||
}}
|
||||
.footer {{
|
||||
font-size: 13px;
|
||||
@@ -97,8 +113,8 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<td class="header">
|
||||
<img
|
||||
style="background-color: #ffffff; border-radius: 8px;"
|
||||
src="https://www.onyx.app/logos/customer/onyx.png"
|
||||
alt="Onyx Logo"
|
||||
src="cid:logo.png"
|
||||
alt="{application_name} Logo"
|
||||
>
|
||||
</td>
|
||||
</tr>
|
||||
@@ -113,9 +129,8 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} Onyx. All rights reserved.
|
||||
<br>
|
||||
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -125,17 +140,27 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
|
||||
|
||||
def build_html_email(
|
||||
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
|
||||
application_name: str | None,
|
||||
heading: str,
|
||||
message: str,
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
else:
|
||||
cta_block = ""
|
||||
return HTML_EMAIL_TEMPLATE.format(
|
||||
application_name=application_name,
|
||||
title=heading,
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
@@ -146,22 +171,44 @@ def send_email(
|
||||
html_body: str,
|
||||
text_body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
inline_png: tuple[str, bytes] | None = None,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
msg["From"] = mail_from
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
# Add text part first (lowest priority)
|
||||
text_part = MIMEText(text_body, "plain")
|
||||
msg.attach(text_part)
|
||||
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
if inline_png:
|
||||
# For HTML with images, create a multipart/related container
|
||||
related = MIMEMultipart("related")
|
||||
|
||||
# Add the HTML part to the related container
|
||||
html_part = MIMEText(html_body, "html")
|
||||
related.attach(html_part)
|
||||
|
||||
# Add image with proper Content-ID to the related container
|
||||
img = MIMEImage(inline_png[1], _subtype="png")
|
||||
img.add_header("Content-ID", f"<{inline_png[0]}>")
|
||||
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
|
||||
related.attach(img)
|
||||
|
||||
# Add the related part to the message (higher priority than text)
|
||||
msg.attach(related)
|
||||
else:
|
||||
# No images, just add HTML directly (higher priority than text)
|
||||
html_part = MIMEText(html_body, "html")
|
||||
msg.attach(html_part)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
@@ -173,8 +220,21 @@ def send_email(
|
||||
|
||||
|
||||
def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
"""This is templated but isn't meaningful for whitelabeling."""
|
||||
|
||||
# Example usage of the reusable HTML
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Your {application_name} Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
@@ -183,23 +243,48 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
)
|
||||
cta_text = "Renew Subscription"
|
||||
cta_link = "https://www.onyx.app/pricing"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
text_content = (
|
||||
"We're sorry to see you go.\n"
|
||||
"Your subscription has been canceled and will end on your next billing date.\n"
|
||||
"If you change your mind, visit https://www.onyx.app/pricing"
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_user_email_invite(
|
||||
user_email: str, current_user: User, auth_type: AuthType
|
||||
) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
onyx_file: FileWithMimeType | None = None
|
||||
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Invitation to Join {application_name} Organization"
|
||||
heading = "You've Been Invited!"
|
||||
|
||||
# the exact action taken by the user, and thus the message, depends on the auth type
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
|
||||
if auth_type == AuthType.CLOUD:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
@@ -225,19 +310,32 @@ def send_user_email_invite(
|
||||
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
|
||||
# text content is the fallback for clients that don't support HTML
|
||||
# not as critical, so not having special cases for each auth type
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
)
|
||||
if auth_type == AuthType.CLOUD:
|
||||
text_content += "You'll be asked to set a password or login with Google to complete your registration."
|
||||
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
@@ -247,27 +345,80 @@ def send_forgot_password_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Reset Your {application_name} Password"
|
||||
heading = "Reset Your Password"
|
||||
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
|
||||
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
|
||||
cta_text = "Reset Password"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
text_content = (
|
||||
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
|
||||
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
new_organization: bool = False,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
subject = "Onyx Email Verification"
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"{application_name} Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
if new_organization:
|
||||
link = add_url_params(link, {"first_user": "true"})
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
html_content = build_html_email("Verify Your Email", message)
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
"Verify Your Email",
|
||||
message,
|
||||
)
|
||||
text_content = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_PENDING_USERS_KEY
|
||||
from onyx.configs.constants import KV_USER_STORE_KEY
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
|
||||
def get_pending_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
return cast(list, store.load(KV_PENDING_USERS_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_pending_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
211
backend/onyx/auth/oauth_refresher.py
Normal file
211
backend/onyx/auth/oauth_refresher.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Standard OAuth refresh token endpoints
|
||||
REFRESH_ENDPOINTS = {
|
||||
"google": "https://oauth2.googleapis.com/token",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Keeping this as a utility function for potential future debugging,
|
||||
# but not using it in production code
|
||||
async def _test_expire_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
expire_in_seconds: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
Utility function for testing - Sets an OAuth token to expire in a short time
|
||||
to facilitate testing of the refresh flow.
|
||||
Not used in production code.
|
||||
"""
|
||||
try:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
|
||||
)
|
||||
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error setting artificial expiration: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to refresh an OAuth token that's about to expire or has expired.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not oauth_account.refresh_token:
|
||||
logger.warning(
|
||||
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
|
||||
)
|
||||
return False
|
||||
|
||||
provider = oauth_account.oauth_name
|
||||
if provider not in REFRESH_ENDPOINTS:
|
||||
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
REFRESH_ENDPOINTS[provider],
|
||||
data={
|
||||
"client_id": OAUTH_CLIENT_ID,
|
||||
"client_secret": OAUTH_CLIENT_SECRET,
|
||||
"refresh_token": oauth_account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to refresh OAuth token: Status {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
new_access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get(
|
||||
"refresh_token", oauth_account.refresh_token
|
||||
)
|
||||
expires_in = token_data.get("expires_in")
|
||||
|
||||
# Calculate new expiry time if provided
|
||||
new_expires_at: Optional[int] = None
|
||||
if expires_in:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expires_in)
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
updated_data: Dict[str, Any] = {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}
|
||||
|
||||
if new_expires_at:
|
||||
updated_data["expires_at"] = new_expires_at
|
||||
|
||||
# Update oidc_expiry in user model if we're tracking it
|
||||
if TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(
|
||||
new_expires_at, tz=timezone.utc
|
||||
)
|
||||
await user_manager.user_db.update(
|
||||
user, {"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error refreshing OAuth token: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_refresh_oauth_tokens(
|
||||
user: User,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Check if any OAuth tokens are expired or about to expire and refresh them.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return
|
||||
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Buffer time to refresh tokens before they expire (in seconds)
|
||||
buffer_seconds = 300 # 5 minutes
|
||||
|
||||
for oauth_account in user.oauth_accounts:
|
||||
# Skip accounts without refresh tokens
|
||||
if not oauth_account.refresh_token:
|
||||
continue
|
||||
|
||||
# If token is about to expire, refresh it
|
||||
if (
|
||||
oauth_account.expires_at
|
||||
and oauth_account.expires_at - now_timestamp < buffer_seconds
|
||||
):
|
||||
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
|
||||
success = await refresh_oauth_token(
|
||||
user, oauth_account, db_session, user_manager
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Failed to refresh OAuth token. User may need to re-authenticate."
|
||||
)
|
||||
|
||||
|
||||
async def check_oauth_account_has_refresh_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an OAuth account has a refresh token.
|
||||
Returns True if a refresh token exists, False otherwise.
|
||||
"""
|
||||
return bool(oauth_account.refresh_token)
|
||||
|
||||
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
|
||||
"""
|
||||
Returns a list of OAuth accounts for a user that are missing refresh tokens.
|
||||
These accounts will need re-authentication to get refresh tokens.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return []
|
||||
|
||||
accounts_needing_refresh = []
|
||||
for oauth_account in user.oauth_accounts:
|
||||
has_refresh_token = await check_oauth_account_has_refresh_token(
|
||||
user, oauth_account
|
||||
)
|
||||
if not has_refresh_token:
|
||||
accounts_needing_refresh.append(oauth_account)
|
||||
|
||||
return accounts_needing_refresh
|
||||
@@ -5,12 +5,16 @@ import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -52,6 +56,7 @@ from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
@@ -100,10 +105,12 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -354,7 +361,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -508,6 +514,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
user: User,
|
||||
request: Optional[Request] = None,
|
||||
response: Optional[Response] = None,
|
||||
) -> None:
|
||||
try:
|
||||
if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
|
||||
response.delete_cookie(
|
||||
ANONYMOUS_USER_COOKIE_NAME,
|
||||
# Ensure cookie deletion doesn't override other cookies by setting the same path/domain
|
||||
path="/",
|
||||
domain=None,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
)
|
||||
logger.debug(f"Deleted anonymous user cookie for user {user.email}")
|
||||
except Exception:
|
||||
logger.exception("Error deleting anonymous user cookie")
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
@@ -579,8 +604,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
logger.notice(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
user_count = await get_user_count()
|
||||
send_user_verification_email(
|
||||
user.email, token, new_organization=user_count == 1
|
||||
)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
@@ -592,7 +619,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
None,
|
||||
POSTGRES_DEFAULT_SCHEMA,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
@@ -686,16 +713,20 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
T = TypeVar("T", covariant=True)
|
||||
ID = TypeVar("ID", contravariant=True)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
# Protocol for strategies that support token refreshing without inheritance.
|
||||
class RefreshableStrategy(Protocol):
|
||||
"""Protocol for authentication strategies that support token refreshing."""
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: Any) -> str:
|
||||
"""
|
||||
Refresh an existing token by extending its lifetime.
|
||||
Returns either the same token with extended expiration or a new token.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
@@ -754,6 +785,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by extending its expiration time in Redis."""
|
||||
if token is None:
|
||||
# If no token provided, create a new one
|
||||
return await self.write_token(user)
|
||||
|
||||
redis = await get_async_redis_connection()
|
||||
token_key = f"{self.key_prefix}{token}"
|
||||
|
||||
# Check if token exists
|
||||
token_data_str = await redis.get(token_key)
|
||||
if not token_data_str:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Token exists, extend its lifetime
|
||||
token_data = json.loads(token_data_str)
|
||||
await redis.set(
|
||||
token_key,
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
|
||||
"""Database strategy with token refreshing capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token_db: AccessTokenDatabase[AccessToken],
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
):
|
||||
super().__init__(access_token_db, lifetime_seconds)
|
||||
self._access_token_db = access_token_db
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by updating its expiration time in the database."""
|
||||
if token is None:
|
||||
return await self.write_token(user)
|
||||
|
||||
# Find the token in database
|
||||
access_token = await self._access_token_db.get_by_token(token)
|
||||
|
||||
if access_token is None:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Update expiration time
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
|
||||
)
|
||||
await self._access_token_db.update(access_token, {"expires": new_expires})
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_redis_strategy() -> TenantAwareRedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> RefreshableDatabaseStrategy:
|
||||
return RefreshableDatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
@@ -804,6 +904,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
return router
|
||||
|
||||
def get_refresh_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for session token refreshing.
|
||||
"""
|
||||
# Import the oauth_refresher here to avoid circular imports
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
refresh_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_login_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
|
||||
)
|
||||
async def refresh(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
|
||||
get_user_manager
|
||||
),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> Response:
|
||||
try:
|
||||
user, token = user_token
|
||||
logger.info(f"Processing token refresh request for user {user.email}")
|
||||
|
||||
# Check if user has OAuth accounts that need refreshing
|
||||
await check_and_refresh_oauth_tokens(
|
||||
user=cast(User, user),
|
||||
db_session=db_session,
|
||||
user_manager=cast(Any, user_manager),
|
||||
)
|
||||
|
||||
# Check if strategy supports refreshing
|
||||
supports_refresh = hasattr(strategy, "refresh_token") and callable(
|
||||
getattr(strategy, "refresh_token")
|
||||
)
|
||||
|
||||
if supports_refresh:
|
||||
try:
|
||||
refresh_method = getattr(strategy, "refresh_token")
|
||||
new_token = await refresh_method(token, user)
|
||||
logger.info(
|
||||
f"Successfully refreshed session token for user {user.email}"
|
||||
)
|
||||
return await backend.transport.get_login_response(new_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing session token: {str(e)}")
|
||||
# Fallback to logout and login if refresh fails
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
|
||||
# Fallback: logout and login again
|
||||
logger.info(
|
||||
"Strategy doesn't support refresh - using logout/login flow"
|
||||
)
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Token refresh failed: {str(e)}",
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
@@ -894,7 +1076,7 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
async def current_chat_accessible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -1037,12 +1219,20 @@ def get_oauth_router(
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
authorization_url = add_url_params(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
@@ -1095,6 +1285,12 @@ def get_oauth_router(
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
@@ -1128,11 +1324,22 @@ def get_oauth_router(
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
for header_name, header_value in response.headers.items():
|
||||
redirect_response.headers[header_name] = header_value
|
||||
# FastAPI can have multiple Set-Cookie headers as a list
|
||||
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
|
||||
for cookie_value in header_value:
|
||||
redirect_response.headers.append(header_name, cookie_value)
|
||||
else:
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
@@ -1140,6 +1347,7 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -34,7 +35,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -225,7 +225,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -306,12 +306,12 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
|
||||
7
backend/onyx/background/celery/apps/client.py
Normal file
7
backend/onyx/background/celery/apps/client.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from celery import Celery
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.client")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
@@ -111,5 +111,8 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -92,5 +92,6 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -38,10 +39,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -94,7 +96,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
@@ -102,7 +104,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
@@ -173,6 +175,9 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
logger.exception(
|
||||
f"Marking attempt {attempt.id} as canceled due to validation error 2"
|
||||
)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
|
||||
@@ -235,7 +240,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
@@ -284,5 +289,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
]
|
||||
)
|
||||
|
||||
16
backend/onyx/background/celery/configs/client.py
Normal file
16
backend/onyx/background/celery/configs/client.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
73
backend/onyx/background/celery/memory_monitoring.py
Normal file
73
backend/onyx/background/celery/memory_monitoring.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# backend/onyx/background/celery/memory_monitoring.py
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import psutil
|
||||
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Regular application logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Only set up memory monitoring in container environment
|
||||
if is_running_in_container():
|
||||
# Set up a dedicated memory monitoring logger
|
||||
MEMORY_LOG_DIR = "/var/log/memory"
|
||||
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
|
||||
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
|
||||
|
||||
# Ensure log directory exists
|
||||
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
|
||||
|
||||
# Create a dedicated logger for memory monitoring
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.setLevel(logging.INFO)
|
||||
|
||||
# Create a rotating file handler
|
||||
memory_handler = RotatingFileHandler(
|
||||
MEMORY_LOG_FILE,
|
||||
maxBytes=MEMORY_LOG_MAX_BYTES,
|
||||
backupCount=MEMORY_LOG_BACKUP_COUNT,
|
||||
)
|
||||
|
||||
# Create a formatter that includes all relevant information
|
||||
memory_formatter = logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
memory_handler.setFormatter(memory_formatter)
|
||||
memory_logger.addHandler(memory_handler)
|
||||
else:
|
||||
# Create a null logger when not in container
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
def emit_process_memory(
|
||||
pid: int, process_name: str, additional_metadata: dict[str, str | int]
|
||||
) -> None:
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
process = psutil.Process(pid)
|
||||
memory_info = process.memory_info()
|
||||
cpu_percent = process.cpu_percent(interval=0.1)
|
||||
|
||||
# Build metadata string from additional_metadata dictionary
|
||||
metadata_str = " ".join(
|
||||
[f"{key}={value}" for key, value in additional_metadata.items()]
|
||||
)
|
||||
metadata_str = f" {metadata_str}" if metadata_str else ""
|
||||
|
||||
memory_logger.info(
|
||||
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
|
||||
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
|
||||
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
|
||||
f"cpu={cpu_percent:.2f}{metadata_str}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error monitoring process memory.")
|
||||
@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
# we have a better implementation (backpressure, etc)
|
||||
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
@@ -63,6 +64,15 @@ beat_task_templates.extend(
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-folder-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
@@ -167,6 +177,16 @@ beat_cloud_tasks: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# tasks that only run self hosted
|
||||
@@ -184,6 +204,16 @@ if not MULTI_TENANT:
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-process-memory",
|
||||
"task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import (
|
||||
delete_all_documents_by_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
@@ -386,6 +389,8 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
credential_id_to_delete: int | None = None
|
||||
connector_id_to_delete: int | None = None
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
@@ -440,16 +445,35 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Store IDs before potentially expiring cc_pair
|
||||
connector_id_to_delete = cc_pair.connector_id
|
||||
credential_id_to_delete = cc_pair.credential_id
|
||||
|
||||
# Explicitly delete document by connector credential pair records before deleting the connector
|
||||
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
|
||||
delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
|
||||
# Flush to ensure document deletion happens before connector deletion
|
||||
db_session.flush()
|
||||
|
||||
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
|
||||
# related to the deleted DocumentByConnectorCredentialPair during commit
|
||||
db_session.expire(cc_pair)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
@@ -482,15 +506,15 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"connector={connector_id_to_delete} "
|
||||
f"credential={credential_id_to_delete} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
@@ -540,7 +564,7 @@ def validate_connector_deletion_fences(
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
queued_upsert_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
@@ -627,7 +651,7 @@ def validate_connector_deletion_fence(
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
if member_str in queued_upsert_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
@@ -17,6 +17,7 @@ from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
@@ -46,7 +47,6 @@ from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -64,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -105,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
|
||||
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -285,7 +289,7 @@ def try_creating_permissions_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
@@ -420,12 +424,7 @@ def connector_permission_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -453,23 +452,23 @@ def connector_permission_sync_generator_task(
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
|
||||
cc_pair, callback
|
||||
)
|
||||
document_external_accesses = doc_sync_func(cc_pair, callback)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
tasks_generated = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
tasks_generated += 1
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
@@ -881,6 +880,18 @@ def monitor_ccpair_permissions_taskset(
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
|
||||
data={
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"total_docs_synced": initial if initial is not None else 0,
|
||||
"remaining_docs_to_sync": remaining,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
@@ -892,6 +903,13 @@ def monitor_ccpair_permissions_taskset(
|
||||
f"num_synced={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing complete
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={"cc_pair_id": cc_pair_id},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
|
||||
@@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -272,7 +271,7 @@ def try_creating_external_group_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
@@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task(
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import should_index
|
||||
@@ -71,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -363,6 +365,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
@@ -400,7 +403,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
redis_client.set(
|
||||
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
|
||||
1,
|
||||
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
|
||||
)
|
||||
|
||||
# 1/3: KICKOFF
|
||||
|
||||
@@ -427,7 +434,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
cc_pairs = fetch_connector_credential_pairs(
|
||||
db_session, include_user_files=True
|
||||
)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
@@ -446,12 +455,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
not search_settings_instance.status.is_current()
|
||||
and not search_settings_instance.background_reindex_enabled
|
||||
):
|
||||
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
|
||||
|
||||
continue
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Skipping fenced connector: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -459,6 +474,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
@@ -472,7 +490,20 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -511,6 +542,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Failed to create indexing task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -984,6 +1021,9 @@ def connector_indexing_proxy_task(
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
@@ -1024,6 +1064,23 @@ def connector_indexing_proxy_task(
|
||||
job.release()
|
||||
break
|
||||
|
||||
# log the memory usage for tracking down memory leaks / connector-specific memory issues
|
||||
pid = job.process.pid
|
||||
if pid is not None:
|
||||
# Only emit memory info once per minute (60 seconds)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_memory_emit_time >= 60.0:
|
||||
emit_process_memory(
|
||||
pid,
|
||||
"indexing_worker",
|
||||
{
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"search_settings_id": search_settings_id,
|
||||
"index_attempt_id": index_attempt_id,
|
||||
},
|
||||
)
|
||||
last_memory_emit_time = current_time
|
||||
|
||||
# if a termination signal is detected, break (exit point will clean up)
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
@@ -1123,6 +1180,9 @@ def connector_indexing_proxy_task(
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
@@ -1170,6 +1230,7 @@ def connector_indexing_proxy_task(
|
||||
return
|
||||
|
||||
|
||||
# primary
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
@@ -1217,6 +1278,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
# light worker
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
|
||||
bind=True,
|
||||
|
||||
@@ -371,6 +371,7 @@ def should_index(
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
@@ -380,6 +381,9 @@ def should_index(
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
|
||||
)
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
@@ -388,19 +392,31 @@ def should_index(
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -412,6 +428,9 @@ def should_index(
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
|
||||
)
|
||||
return False
|
||||
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -424,11 +443,16 @@ def should_index(
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
|
||||
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -508,6 +532,13 @@ def try_creating_indexing_task(
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_INDEXING
|
||||
)
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
@@ -518,7 +549,7 @@ def try_creating_indexing_task(
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
@@ -19,6 +20,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -39,8 +41,10 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
@@ -904,3 +908,93 @@ def monitor_celery_queues_helper(
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
|
||||
"""Memory monitoring"""
|
||||
|
||||
|
||||
def _get_cmdline_for_process(process: psutil.Process) -> str | None:
|
||||
try:
|
||||
return " ".join(process.cmdline())
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_process_memory(self: Task, *, tenant_id: str) -> None:
|
||||
"""
|
||||
Task to monitor memory usage of supervisor-managed processes.
|
||||
This periodically checks the memory usage of processes and logs information
|
||||
in a standardized format.
|
||||
|
||||
The task looks for processes managed by supervisor and logs their
|
||||
memory usage statistics. This is useful for monitoring memory consumption
|
||||
over time and identifying potential memory leaks.
|
||||
"""
|
||||
# don't run this task in multi-tenant mode, have other, better means of monitoring
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
# Get all supervisor-managed processes
|
||||
supervisor_processes: dict[int, str] = {}
|
||||
|
||||
# Map cmd line elements to more readable process names
|
||||
process_type_mapping = {
|
||||
"--hostname=primary": "primary",
|
||||
"--hostname=light": "light",
|
||||
"--hostname=heavy": "heavy",
|
||||
"--hostname=indexing": "indexing",
|
||||
"--hostname=monitoring": "monitoring",
|
||||
"beat": "beat",
|
||||
"slack/listener.py": "slack",
|
||||
}
|
||||
|
||||
# Find all python processes that are likely celery workers
|
||||
for proc in psutil.process_iter():
|
||||
cmdline = _get_cmdline_for_process(proc)
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Match supervisor-managed processes
|
||||
for process_name, process_type in process_type_mapping.items():
|
||||
if process_name in cmdline:
|
||||
if process_type in supervisor_processes.values():
|
||||
task_logger.error(
|
||||
f"Duplicate process type for type {process_type} "
|
||||
f"with cmd {cmdline} with pid={proc.pid}."
|
||||
)
|
||||
continue
|
||||
|
||||
supervisor_processes[proc.pid] = process_type
|
||||
break
|
||||
|
||||
if len(supervisor_processes) != len(process_type_mapping):
|
||||
task_logger.error(
|
||||
"Missing processes: "
|
||||
f"{set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
|
||||
)
|
||||
|
||||
# Log memory usage for each process
|
||||
for pid, process_type in supervisor_processes.items():
|
||||
try:
|
||||
emit_process_memory(pid, process_type, {})
|
||||
except psutil.NoSuchProcess:
|
||||
# Process may have terminated since we obtained the list
|
||||
continue
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error monitoring process {pid}: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in monitor_process_memory task")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user