mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 16:55:46 +00:00

Compare commits: cloud_buil... → gating
23 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 8d66bdd061 | |
| | 8f67f1715c | |
| | 3b365509e2 | |
| | 022cbdfccf | |
| | ebec6f6b10 | |
| | 1cad9c7b3d | |
| | b4e975013c | |
| | dd26f92206 | |
| | 4d00ec45ad | |
| | 1a81c67a67 | |
| | 04f965e656 | |
| | 277d37e0ee | |
| | 3cd260131b | |
| | ad21ee0e9a | |
| | c7dc0e9af0 | |
| | 75c5de802b | |
| | c39f590d0d | |
| | 82a9fda846 | |
| | 842d4ab2a8 | |
| | cddcec4ea4 | |
| | 09dd7b424c | |
| | a2fd8d5e0a | |
| | 802dc00f78 | |
30  .github/pull_request_template.md  vendored
@@ -6,24 +6,20 @@
[Describe the tests you ran to verify your changes]

## Accepted Risk (provide if relevant)
N/A
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]

## Related Issue(s) (provide if relevant)
N/A
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]

## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge

## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
23  .github/workflows/nightly-close-stale-issues.yml  vendored
@@ -1,23 +0,0 @@
name: 'Nightly - Close stale issues and PRs'
on:
  schedule:
    - cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC

permissions:
  # contents: write # only for delete-branch option
  issues: write
  pull-requests: write

jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/stale@v9
        with:
          stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
          stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
          close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
          close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
          days-before-stale: 75
          # days-before-close: 90 # uncomment after we test stale behavior
65  .github/workflows/pr-Integration-tests.yml  vendored
@@ -72,7 +72,7 @@ jobs:
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build integration test Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
@@ -85,58 +85,7 @@ jobs:
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      # Start containers for multi-tenant tests
      - name: Start Docker containers for multi-tenant tests
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
          MULTI_TENANT=true \
          AUTH_TYPE=basic \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=test \
          docker compose -f docker-compose.dev.yml -p danswer-stack up -d
        id: start_docker_multi_tenant

      # In practice, `cloud` Auth type would require OAUTH credentials to be set.
      - name: Run Multi-Tenant Integration Tests
        run: |
          echo "Running integration tests..."
          docker run --rm --network danswer-stack_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \
            -e POSTGRES_DB=postgres \
            -e VESPA_HOST=index \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e AUTH_TYPE=cloud \
            -e MULTI_TENANT=true \
            danswer/danswer-integration:test \
            /app/tests/integration/multitenant_tests
        continue-on-error: true
        id: run_multitenant_tests

      - name: Check multi-tenant test results
        run: |
          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
            echo "Integration tests failed. Exiting with error."
            exit 1
          else
            echo "All integration tests passed successfully."
          fi

      - name: Stop multi-tenant Docker containers
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p danswer-stack down -v

      - name: Start Docker containers
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -181,7 +130,7 @@ jobs:
          done
          echo "Finished waiting for service."

      - name: Run Standard Integration Tests
      - name: Run integration tests
        run: |
          echo "Running integration tests..."
          docker run --rm --network danswer-stack_default \
@@ -196,8 +145,7 @@
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            danswer/danswer-integration:test \
            /app/tests/integration/tests
            danswer/danswer-integration:test
        continue-on-error: true
        id: run_tests

@@ -210,11 +158,6 @@
            echo "All integration tests passed successfully."
          fi

      - name: Stop Docker containers
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p danswer-stack down -v

      - name: Save Docker logs
        if: success() || failure()
        run: |
98  .github/workflows/pr-backport-autotrigger.yml  vendored
@@ -1,98 +0,0 @@
name: Backport on Merge

on:
  pull_request:
    types: [closed] # Later we check for merge so only PRs that go in can get backported

permissions:
  contents: write
  actions: write # Allows the workflow to trigger other workflows on push

jobs:
  backport:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 0 # Fetch all history for all branches and tags

      - name: Set up Git
        run: |
          git config --global user.name "yuhongsun96"
          git config --global user.email "yuhongsun96@gmail.com"
          # Configure Git to use the PAT for authentication
          git remote set-url origin https://yuhongsun96:${{ secrets.YUHONG_GH_ACTIONS }}@github.com/${{ github.repository }}.git

      - name: Check for Backport Checkbox
        id: checkbox-check
        run: |
          PR_BODY="${{ github.event.pull_request.body }}"
          if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
            echo "backport=true" >> $GITHUB_OUTPUT
          else
            echo "backport=false" >> $GITHUB_OUTPUT
          fi

      - name: List and sort release branches
        id: list-branches
        run: |
          git fetch --all --tags
          BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
          BETA=$(echo "$BRANCHES" | head -n 1)
          STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
          echo "beta=release/$BETA" >> $GITHUB_OUTPUT
          echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
          # Fetch latest tags for beta and stable
          LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
          LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
          # Increment latest beta tag
          NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
          # Increment latest stable tag
          NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
          echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
          echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
          echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
          echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT

      - name: Echo branch and tag information
        run: |
          echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
          echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
          echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
          echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
          echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
          echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"

      - name: Trigger Backport
        if: steps.checkbox-check.outputs.backport == 'true'
        run: |
          set -e
          echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
          # Fetch all history for all branches and tags
          git fetch --prune
          # Checkout the beta branch
          git checkout ${{ steps.list-branches.outputs.beta }}
          # Cherry-pick the merge commit from the merged PR
          git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
            echo "Cherry-pick to beta failed due to conflicts."
            exit 1
          }
          # Create new beta tag
          git tag ${{ steps.list-branches.outputs.new_beta_tag }}
          # Push the changes and tag to the beta branch
          git push origin ${{ steps.list-branches.outputs.beta }} --force
          git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
          # Checkout the stable branch
          git checkout ${{ steps.list-branches.outputs.stable }}
          # Cherry-pick the merge commit from the merged PR
          git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
            echo "Cherry-pick to stable failed due to conflicts."
            exit 1
          }
          # Create new stable tag
          git tag ${{ steps.list-branches.outputs.new_stable_tag }}
          # Push the changes and tag to the stable branch
          git push origin ${{ steps.list-branches.outputs.stable }} --force
          git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
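The awk one-liners above derive the next release tag from the most recent one. A minimal Python sketch of the same increment rules, using made-up tag values rather than anything taken from the repository:

```python
# Sketch of the tag-bump logic implemented by the workflow's awk one-liners.
# The input tags below are hypothetical examples.
def next_beta_tag(tag: str) -> str:
    # v0.8.3-beta.4 -> v0.8.3-beta.5 (bump the trailing beta counter)
    base, counter = tag.rsplit(".", 1)
    return f"{base}.{int(counter) + 1}"

def next_stable_tag(tag: str) -> str:
    # v0.8.3 -> v0.8.4 (bump the patch component)
    major, minor, patch = tag.split(".")
    return f"{major}.{minor}.{int(patch) + 1}"

assert next_beta_tag("v0.8.3-beta.4") == "v0.8.3-beta.5"
assert next_stable_tag("v0.8.3") == "v0.8.4"
```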
300  .vscode/launch.template.jsonc  vendored
@@ -6,69 +6,19 @@
  // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
  "version": "0.2.0",
  "compounds": [
    {
      // Dummy entry used to label the group
      "name": "--- Compound ---",
      "configurations": [
        "--- Individual ---"
      ],
      "presentation": {
        "group": "1",
      }
    },
    {
      "name": "Run All Danswer Services",
      "configurations": [
        "Web Server",
        "Model Server",
        "API Server",
        "Slack Bot",
        "Celery primary",
        "Celery light",
        "Celery heavy",
        "Celery indexing",
        "Celery beat",
      ],
      "presentation": {
        "group": "1",
      }
    },
    {
      "name": "Web / Model / API",
      "configurations": [
        "Web Server",
        "Model Server",
        "API Server",
      ],
      "presentation": {
        "group": "1",
      }
    },
    {
      "name": "Celery (all)",
      "configurations": [
        "Celery primary",
        "Celery light",
        "Celery heavy",
        "Celery indexing",
        "Celery beat"
      ],
      "presentation": {
        "group": "1",
      }
    }
        "Indexing",
        "Background Jobs",
        "Slack Bot"
      ]
    }
  ],
  "configurations": [
    {
      // Dummy entry used to label the group
      "name": "--- Individual ---",
      "type": "node",
      "request": "launch",
      "presentation": {
        "group": "2",
        "order": 0
      }
    },
    {
      "name": "Web Server",
      "type": "node",
@@ -79,11 +29,7 @@
      "runtimeArgs": [
        "run", "dev"
      ],
      "presentation": {
        "group": "2",
      },
      "console": "integratedTerminal",
      "consoleTitle": "Web Server Console"
      "console": "integratedTerminal"
    },
    {
      "name": "Model Server",
@@ -102,11 +48,7 @@
        "--reload",
        "--port",
        "9000"
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Model Server Console"
      ]
    },
    {
      "name": "API Server",
@@ -126,13 +68,43 @@
        "--reload",
        "--port",
        "8080"
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "API Server Console"
      ]
    },
    // For the listener to access the Slack API,
    {
      "name": "Indexing",
      "consoleName": "Indexing",
      "type": "debugpy",
      "request": "launch",
      "program": "danswer/background/update.py",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "ENABLE_MULTIPASS_INDEXING": "false",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      }
    },
    // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
    {
      "name": "Background Jobs",
      "consoleName": "Background Jobs",
      "type": "debugpy",
      "request": "launch",
      "program": "scripts/dev_run_background_jobs.py",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_DANSWER_MODEL_INTERACTIONS": "True",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "--no-indexing"
      ]
    },
    // For the listner to access the Slack API,
    // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
    {
      "name": "Slack Bot",
@@ -146,151 +118,7 @@
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Slack Bot Console"
    },
    {
      "name": "Celery primary",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "INFO",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "danswer.background.celery.versioned_apps.primary",
        "worker",
        "--pool=threads",
        "--concurrency=4",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=primary@%n",
        "-Q",
        "celery",
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Celery primary Console"
    },
    {
      "name": "Celery light",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "INFO",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "danswer.background.celery.versioned_apps.light",
        "worker",
        "--pool=threads",
        "--concurrency=64",
        "--prefetch-multiplier=8",
        "--loglevel=INFO",
        "--hostname=light@%n",
        "-Q",
        "vespa_metadata_sync,connector_deletion",
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Celery light Console"
    },
    {
      "name": "Celery heavy",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "INFO",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "danswer.background.celery.versioned_apps.heavy",
        "worker",
        "--pool=threads",
        "--concurrency=4",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=heavy@%n",
        "-Q",
        "connector_pruning",
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Celery heavy Console"
    },
    {
      "name": "Celery indexing",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "ENABLE_MULTIPASS_INDEXING": "false",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "danswer.background.celery.versioned_apps.indexing",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=indexing@%n",
        "-Q",
        "connector_indexing",
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Celery indexing Console"
    },
    {
      "name": "Celery beat",
      "type": "debugpy",
      "request": "launch",
      "module": "celery",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-A",
        "danswer.background.celery.versioned_apps.beat",
        "beat",
        "--loglevel=INFO",
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Celery beat Console"
    }
    },
    {
      "name": "Pytest",
@@ -309,22 +137,8 @@
        "-v"
        // Specify a sepcific module/test to run or provide nothing to run all tests
        //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
      ],
      "presentation": {
        "group": "2",
      },
      "consoleTitle": "Pytest Console"
      ]
    },
    {
      // Dummy entry used to label the group
      "name": "--- Tasks ---",
      "type": "node",
      "request": "launch",
      "presentation": {
        "group": "3",
        "order": 0
      }
    },
    {
      "name": "Clear and Restart External Volumes and Containers",
      "type": "node",
@@ -333,27 +147,7 @@
      "runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
      "cwd": "${workspaceFolder}",
      "console": "integratedTerminal",
      "stopOnEntry": true,
      "presentation": {
        "group": "3",
      },
    },
    {
      // Celery jobs launched through a single background script (legacy)
      // Recommend using the "Celery (all)" compound launch instead.
      "name": "Background Jobs",
      "consoleName": "Background Jobs",
      "type": "debugpy",
      "request": "launch",
      "program": "scripts/dev_run_background_jobs.py",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.vscode/.env",
      "env": {
        "LOG_DANSWER_MODEL_INTERACTIONS": "True",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
    },
    "stopOnEntry": true
    }
  ]
}
@@ -68,7 +68,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b

## 🚧 Roadmap
* Chat/Prompt sharing with specific teammates and user groups.
* Multimodal model support, chat with images, video etc.
* Multi-Model model support, chat with images, video etc.
* Choosing between LLMs and parameters during chat session.
* Tool calling and agent configurations options.
* Organizational understanding and ability to locate and suggest experts from your team.
@@ -13,8 +13,7 @@ from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase  # type: ignore
from danswer.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from danswer.background.celery.celery_app import get_all_tenant_ids

# Alembic Config object
config = context.config
@@ -58,15 +57,11 @@ def get_schema_options() -> tuple[str, bool, bool]:
        if "=" in pair:
            key, value = pair.split("=", 1)
            x_args[key.strip()] = value.strip()
    schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
    schema_name = x_args.get("schema", "public")
    create_schema = x_args.get("create_schema", "true").lower() == "true"
    upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"

    if (
        MULTI_TENANT
        and schema_name == POSTGRES_DEFAULT_SCHEMA
        and not upgrade_all_tenants
    ):
    if MULTI_TENANT and schema_name == "public":
        raise ValueError(
            "Cannot run default migrations in public schema when multi-tenancy is enabled. "
            "Please specify a tenant-specific schema."
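The schema options above arrive through Alembic's `-x` mechanism. A self-contained sketch of how that parsing would resolve a hypothetical invocation such as `alembic -x schema=tenant_abc -x create_schema=false upgrade head` (the tenant name is invented):

```python
# Standalone sketch of the x-argument parsing in get_schema_options above.
def parse_x_args(pairs: list[str]) -> dict[str, str]:
    x_args: dict[str, str] = {}
    for pair in pairs:
        if "=" in pair:
            key, value = pair.split("=", 1)
            x_args[key.strip()] = value.strip()
    return x_args

# "tenant_abc" is a hypothetical schema name, not one from the repository.
args = parse_x_args(["schema=tenant_abc", "create_schema=false"])
schema_name = args.get("schema", "public")  # -> "tenant_abc"
create_schema = args.get("create_schema", "true").lower() == "true"  # -> False
```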
@@ -1,9 +1,7 @@
"""Migrate chat_session and chat_message tables to use UUID primary keys

"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537

"""
from alembic import op
import sqlalchemy as sa
@@ -14,6 +12,8 @@ branch_labels = None
depends_on = None

"""
Migrate chat_session and chat_message tables to use UUID primary keys.

This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
@@ -31,12 +31,6 @@ def upgrade() -> None:


def downgrade() -> None:
    # First, update any null values to a default value
    op.execute(
        "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
    )

    # Then, make the column non-nullable
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
@@ -1,3 +1,3 @@
import os

__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
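One detail worth noting in this pattern: using `or` instead of the `get` default means an empty `DANSWER_VERSION` also falls back. A small illustration (the values are hypothetical):

```python
import os

# An env var set to the empty string defeats a dict-style default but not `or`:
os.environ["DANSWER_VERSION"] = ""
assert os.environ.get("DANSWER_VERSION", "Development") == ""  # default skipped
assert (os.environ.get("DANSWER_VERSION", "") or "Development") == "Development"
```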
@@ -70,12 +70,3 @@ class DocumentAccess(ExternalAccess):
        user_groups=set(user_groups),
        is_public=is_public,
    )


default_public_access = DocumentAccess(
    external_user_emails=set(),
    external_user_group_ids=set(),
    user_emails=set(),
    user_groups=set(),
    is_public=True,
)
@@ -9,7 +9,6 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
def get_invited_users() -> list[str]:
    try:
        store = get_kv_store()

        return cast(list, store.load(KV_USER_STORE_KEY))
    except KvKeyNotFoundError:
        return list()
@@ -58,7 +58,6 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
@@ -94,9 +93,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

from shared_configs.configs import current_tenant_id

logger = setup_logger()

@@ -135,9 +132,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
    # all other auth types besides basic should require users to be
    # verified
    return not DISABLE_VERIFICATION and (
        AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
    )
    return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION


def verify_email_is_invited(email: str) -> None:
@@ -192,7 +187,7 @@ def verify_email_domain(email: str) -> None:

def get_tenant_id_for_email(email: str) -> str:
    if not MULTI_TENANT:
        return POSTGRES_DEFAULT_SCHEMA
        return "public"
    # Implement logic to get tenant_id from the mapping table
    with Session(get_sqlalchemy_engine()) as db_session:
        result = db_session.execute(
@@ -240,9 +235,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    ) -> User:
        try:
            tenant_id = (
                get_tenant_id_for_email(user_create.email)
                if MULTI_TENANT
                else POSTGRES_DEFAULT_SCHEMA
                get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
            )
        except exceptions.UserNotExists:
            raise HTTPException(status_code=401, detail="User not found")
@@ -253,7 +246,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        )

        async with get_async_session_with_tenant(tenant_id) as db_session:
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
            token = current_tenant_id.set(tenant_id)

            verify_email_is_invited(user_create.email)
            verify_email_domain(user_create.email)
@@ -292,7 +285,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            else:
                raise exceptions.UserAlreadyExists()

            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
            current_tenant_id.reset(token)
            return user

    async def on_after_login(
@@ -334,9 +327,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        # Get tenant_id from mapping table
        try:
            tenant_id = (
                get_tenant_id_for_email(account_email)
                if MULTI_TENANT
                else POSTGRES_DEFAULT_SCHEMA
                get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
            )
        except exceptions.UserNotExists:
            raise HTTPException(status_code=401, detail="User not found")
@@ -346,11 +337,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

        token = None
        async with get_async_session_with_tenant(tenant_id) as db_session:
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
            token = current_tenant_id.set(tenant_id)

            verify_email_in_whitelist(account_email, tenant_id)
            verify_email_domain(account_email)

            if MULTI_TENANT:
                tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
                self.user_db = tenant_user_db
@@ -436,7 +426,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        user.oidc_expiry = None  # type: ignore

        if token:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
            current_tenant_id.reset(token)

        return user
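The `CURRENT_TENANT_ID_CONTEXTVAR`/`current_tenant_id` set-and-reset calls above follow the standard `contextvars` token pattern. A self-contained sketch of that pattern with an illustrative variable (the real one is defined in `shared_configs.configs`):

```python
from contextvars import ContextVar

# Illustrative stand-in for the tenant contextvar; the default is hypothetical.
current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")

def do_work_for_tenant(tenant_id: str) -> None:
    token = current_tenant_id.set(tenant_id)  # remember the previous value
    try:
        # Code called from here observes current_tenant_id.get() == tenant_id,
        # including async code running on the same context.
        ...
    finally:
        current_tenant_id.reset(token)  # restore whatever was set before
```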
@@ -1,310 +0,0 @@
import logging
import multiprocessing
import time
from typing import Any

import sentry_sdk
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy  # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration

from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import SENTRY_DSN


logger = setup_logger()

task_logger = get_task_logger(__name__)

if SENTRY_DSN:
    sentry_sdk.init(
        dsn=SENTRY_DSN,
        integrations=[CeleryIntegration()],
        traces_sample_rate=0.1,
    )
    logger.info("Sentry initialized")
else:
    logger.debug("Sentry DSN not provided, skipping Sentry initialization")


def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    pass


def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict[str, Any] | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """We handle this signal in order to remove completed tasks
    from their respective tasksets. This allows us to track the progress of document set
    and user group syncs.

    This function runs after any task completes (both success and failure)
    Note that this signal does not fire on a task that failed to complete and is going
    to be retried.

    This also does not fire if a worker with acks_late=False crashes (which all of our
    long running workers are)
    """
    if not task:
        return

    task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")

    if state not in READY_STATES:
        return

    if not task_id:
        return

    # Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
    if not kwargs:
        logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
        tenant_id = None
    else:
        tenant_id = kwargs.get("tenant_id")

    task_logger.debug(
        f"Task {task.name} (ID: {task_id}) completed with state: {state} "
        f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
    )

    r = get_redis_client(tenant_id=tenant_id)

    if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
        r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
        return

    if task_id.startswith(RedisDocumentSet.PREFIX):
        document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
        if document_set_id is not None:
            rds = RedisDocumentSet(int(document_set_id))
            r.srem(rds.taskset_key, task_id)
        return

    if task_id.startswith(RedisUserGroup.PREFIX):
        usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
        if usergroup_id is not None:
            rug = RedisUserGroup(int(usergroup_id))
            r.srem(rug.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorDeletion.PREFIX):
        cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            rcd = RedisConnectorDeletion(int(cc_pair_id))
            r.srem(rcd.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
        cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            rcp = RedisConnectorPruning(int(cc_pair_id))
            r.srem(rcp.taskset_key, task_id)
        return


def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """The first signal sent on celery worker startup"""
    multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn


def wait_for_redis(sender: Any, **kwargs: Any) -> None:
    r = get_redis_client(tenant_id=None)

    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    time_start = time.monotonic()
    logger.info("Redis: Readiness check starting.")
    while True:
        try:
            if r.ping():
                break
        except Exception:
            pass

        time_elapsed = time.monotonic() - time_start
        logger.info(
            f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )
        if time_elapsed > WAIT_LIMIT:
            msg = (
                f"Redis: Readiness check did not succeed within the timeout "
                f"({WAIT_LIMIT} seconds). Exiting..."
            )
            logger.error(msg)
            raise WorkerShutdown(msg)

        time.sleep(WAIT_INTERVAL)

    logger.info("Redis: Readiness check succeeded. Continuing...")
    return


def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    logger.info("Running as a secondary celery worker.")
    logger.info("Waiting for all tenant primary workers to be ready...")
    time_start = time.monotonic()

    while True:
        tenant_ids = get_all_tenant_ids()
        # Check if we have a primary worker lock for each tenant
        all_tenants_ready = all(
            get_redis_client(tenant_id=tenant_id).exists(
                DanswerRedisLocks.PRIMARY_WORKER
            )
            for tenant_id in tenant_ids
        )

        if all_tenants_ready:
            break

        time_elapsed = time.monotonic() - time_start
        ready_tenants = sum(
            1
            for tenant_id in tenant_ids
            if get_redis_client(tenant_id=tenant_id).exists(
                DanswerRedisLocks.PRIMARY_WORKER
            )
        )

        logger.info(
            f"Not all tenant primary workers are ready yet. "
            f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
            f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )

        if time_elapsed > WAIT_LIMIT:
            msg = (
                f"Not all tenant primary workers were ready within the timeout "
                f"({WAIT_LIMIT} seconds). Exiting..."
            )
            logger.error(msg)
            raise WorkerShutdown(msg)

        time.sleep(WAIT_INTERVAL)

    logger.info("All tenant primary workers are ready. Continuing...")
    return


def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    task_logger.info("worker_ready signal received.")


def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    if not celery_is_worker_primary(sender):
        return

    if not hasattr(sender, "primary_worker_locks"):
        return

    for tenant_id, lock in sender.primary_worker_locks.items():
        try:
            if lock and lock.owned():
                logger.debug(f"Attempting to release lock for tenant {tenant_id}")
                try:
                    lock.release()
                    logger.debug(f"Successfully released lock for tenant {tenant_id}")
                except Exception as e:
                    logger.error(
                        f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
                    )
                finally:
                    sender.primary_worker_locks[tenant_id] = None
        except Exception as e:
            logger.error(
                f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
            )


def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    # TODO: could unhardcode format and colorize and accept these as options from
    # celery's config

    # reformats the root logger
    root_logger = logging.getLogger()

    root_handler = logging.StreamHandler()  # Set up a handler for the root logger
    root_formatter = ColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    root_handler.setFormatter(root_formatter)
    root_logger.addHandler(root_handler)  # Apply the handler to the root logger

    if logfile:
        root_file_handler = logging.FileHandler(logfile)
        root_file_formatter = PlainFormatter(
            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
            datefmt="%m/%d/%Y %I:%M:%S %p",
        )
        root_file_handler.setFormatter(root_file_formatter)
        root_logger.addHandler(root_file_handler)

    root_logger.setLevel(loglevel)

    # reformats celery's task logger
    task_formatter = CeleryTaskColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    task_handler = logging.StreamHandler()  # Set up a handler for the task logger
    task_handler.setFormatter(task_formatter)
    task_logger.addHandler(task_handler)  # Apply the handler to the task logger

    if logfile:
        task_file_handler = logging.FileHandler(logfile)
        task_file_formatter = CeleryTaskPlainFormatter(
            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
            datefmt="%m/%d/%Y %I:%M:%S %p",
        )
        task_file_handler.setFormatter(task_file_formatter)
        task_logger.addHandler(task_file_handler)

    task_logger.setLevel(loglevel)
    task_logger.propagate = False

    # hide celery task received spam
    # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
    strategy.logger.setLevel(logging.WARNING)

    # hide celery task succeeded/failed spam
    # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
    trace.logger.setLevel(logging.WARNING)
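on_task_postrun can only clean up a taskset because the entity id is recoverable from the task id itself. A hypothetical sketch of that convention follows; the real prefixes and encoding live in danswer.background.celery.celery_redis and may differ:

```python
# Hypothetical illustration of prefix-encoded task ids; not the repository's
# actual format.
PREFIX = "documentset"

def make_task_id(entity_id: int, nonce: str) -> str:
    return f"{PREFIX}_{entity_id}_{nonce}"

def get_id_from_task_id(task_id: str) -> int | None:
    parts = task_id.split("_")
    if len(parts) < 3 or parts[0] != PREFIX:
        return None
    return int(parts[1])

# A completed task's id maps back to the document set it belongs to, so the
# postrun handler can srem it from that set's redis taskset.
assert get_id_from_task_id(make_task_id(42, "deadbeef")) == 42
```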
@@ -1,99 +0,0 @@
from datetime import timedelta
from typing import Any

from celery import Celery
from celery import signals
from celery.signals import beat_init

import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger

logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")


@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    logger.info("beat_init signal received.")
    SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
    SqlEngine.init_engine(pool_size=2, max_overflow=0)
    app_base.wait_for_redis(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


#####
# Celery Beat (Periodic Tasks) Settings
#####

tenant_ids = get_all_tenant_ids()

tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-connector-deletion",
        "task": "check_for_connector_deletion_task",
        "schedule": timedelta(seconds=60),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-prune",
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "kombu-message-cleanup",
        "task": "kombu_message_cleanup_task",
        "schedule": timedelta(seconds=3600),
        "options": {"priority": DanswerCeleryPriority.LOWEST},
    },
    {
        "name": "monitor-vespa-sync",
        "task": "monitor_vespa_sync",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
]

# Build the celery beat schedule dynamically
beat_schedule = {}

for tenant_id in tenant_ids:
    for task in tasks_to_schedule:
        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
        beat_schedule[task_name] = {
            "task": task["task"],
            "schedule": task["schedule"],
            "options": task["options"],
            "kwargs": {"tenant_id": tenant_id},  # Must pass tenant_id as an argument
        }

# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)

# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
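The loop above yields one schedule entry per task per tenant. A sketch of the resulting shape for a hypothetical tenant (the tenant name is invented, not captured output):

```python
from datetime import timedelta

from danswer.configs.constants import DanswerCeleryPriority

# What beat_schedule["check-for-indexing-tenant_abc"] would look like for a
# hypothetical tenant "tenant_abc":
entry = {
    "task": "check_for_indexing",
    "schedule": timedelta(seconds=10),
    "options": {"priority": DanswerCeleryPriority.HIGH},
    "kwargs": {"tenant_id": "tenant_abc"},  # every beat task carries its tenant
}
```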
@@ -1,88 +0,0 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.heavy")


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.pruning",
    ]
)
@@ -1,88 +0,0 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.indexing")


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.indexing",
    ]
)
@@ -1,89 +0,0 @@
import multiprocessing
from typing import Any

from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.light")


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.shared",
        "danswer.background.celery.tasks.vespa",
    ]
)
@@ -1,300 +0,0 @@
import multiprocessing
from typing import Any

from celery import bootsteps  # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.primary")


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    app_base.wait_for_redis(sender, **kwargs)

    logger.info("Running as the primary celery worker.")

    sender.primary_worker_locks = {}

    # This is singleton work that should be done on startup exactly once
    # by the primary worker
    tenant_ids = get_all_tenant_ids()
    for tenant_id in tenant_ids:
        r = get_redis_client(tenant_id=tenant_id)

        # For the moment, we're assuming that we are the only primary worker
        # that should be running.
        # TODO: maybe check for or clean up another zombie primary worker if we detect it
        r.delete(DanswerRedisLocks.PRIMARY_WORKER)

        # this process wide lock is taken to help other workers start up in order.
        # it is planned to use this lock to enforce singleton behavior on the primary
        # worker, since the primary worker does redis cleanup on startup, but this isn't
        # implemented yet.
        lock = r.lock(
            DanswerRedisLocks.PRIMARY_WORKER,
            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
        )

        logger.info("Primary worker lock: Acquire starting.")
        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
        if acquired:
            logger.info("Primary worker lock: Acquire succeeded.")
        else:
            logger.error("Primary worker lock: Acquire failed!")
            raise WorkerShutdown("Primary worker lock could not be acquired!")

        # tacking on our own user data to the sender
        sender.primary_worker_locks[tenant_id] = lock

        # As currently designed, when this worker starts as "primary", we reinitialize redis
        # to a clean state (for our purposes, anyway)
        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

        r.delete(RedisConnectorCredentialPair.get_taskset_key())
        r.delete(RedisConnectorCredentialPair.get_fence_key())

        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
            r.delete(key)


# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
#     """This only runs inside child processes when the worker is in pool=prefork mode.
#     This may be technically unnecessary since we're finding prefork pools to be
#     unstable and currently aren't planning on using them."""
#     logger.info("worker_process_init signal received.")
#     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
#     SqlEngine.init_engine(pool_size=5, max_overflow=0)

#     # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
#     SqlEngine.get_engine().dispose(close=False)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


class HubPeriodicTask(bootsteps.StartStopStep):
    """Regularly reacquires the primary worker lock outside of the task queue.
    Use the task_logger in this class to avoid double logging.

    This cannot be done inside a regular beat task because it must run on schedule and
    a queue of existing work would starve the task from running.
    """

    # it's unclear to me whether using the hub's timer or the bootstep timer is better
    requires = {"celery.worker.components:Hub"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
        self.task_tref = None

    def start(self, worker: Any) -> None:
        if not celery_is_worker_primary(worker):
            return

        # Access the worker's event loop (hub)
        hub = worker.consumer.controller.hub

        # Schedule the periodic task
        self.task_tref = hub.call_repeatedly(
            self.interval, self.run_periodic_task, worker
        )
        task_logger.info("Scheduled periodic task with hub.")

    def run_periodic_task(self, worker: Any) -> None:
        try:
            if not celery_is_worker_primary(worker):
                return

            if not hasattr(worker, "primary_worker_locks"):
                return

            # Retrieve all tenant IDs
            tenant_ids = get_all_tenant_ids()

            for tenant_id in tenant_ids:
                lock = worker.primary_worker_locks.get(tenant_id)
                if not lock:
                    continue  # Skip if no lock for this tenant

                r = get_redis_client(tenant_id=tenant_id)

                if lock.owned():
                    task_logger.debug(
                        f"Reacquiring primary worker lock for tenant {tenant_id}."
                    )
                    lock.reacquire()
                else:
                    task_logger.warning(
                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
                        "Reasons could be worker restart or lock expiration."
                    )
                    lock = r.lock(
                        DanswerRedisLocks.PRIMARY_WORKER,
                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                    )

                    task_logger.info(
                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
                    )
                    acquired = lock.acquire(
                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                    )
                    if acquired:
                        task_logger.info(
                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
                        )
                        worker.primary_worker_locks[tenant_id] = lock
                    else:
                        task_logger.error(
                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
                        )
                        raise TimeoutError(
                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
                        )

        except Exception:
            task_logger.exception("Periodic task failed.")

    def stop(self, worker: Any) -> None:
        # Cancel the scheduled task when the worker stops
        if self.task_tref:
            self.task_tref.cancel()
            task_logger.info("Canceled periodic task with hub.")


celery_app.steps["worker"].add(HubPeriodicTask)

celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.connector_deletion",
        "danswer.background.celery.tasks.indexing",
        "danswer.background.celery.tasks.periodic",
        "danswer.background.celery.tasks.pruning",
        "danswer.background.celery.tasks.shared",
        "danswer.background.celery.tasks.vespa",
    ]
)
@@ -1,26 +0,0 @@
import logging

from celery import current_task

from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter


class CeleryTaskPlainFormatter(PlainFormatter):
    def format(self, record: logging.LogRecord) -> str:
        task = current_task
        if task and task.request:
            record.__dict__.update(task_id=task.request.id, task_name=task.name)
            record.msg = f"[{task.name}({task.request.id})] {record.msg}"

        return super().format(record)


class CeleryTaskColoredFormatter(ColoredFormatter):
    def format(self, record: logging.LogRecord) -> str:
        task = current_task
        if task and task.request:
            record.__dict__.update(task_id=task.request.id, task_name=task.name)
            record.msg = f"[{task.name}({task.request.id})] {record.msg}"

        return super().format(record)
619
backend/danswer/background/celery/celery_app.py
Normal file
@@ -0,0 +1,619 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any

import sentry_sdk
from celery import bootsteps  # type: ignore
from celery import Celery
from celery import current_task
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from sentry_sdk.integrations.celery import CeleryIntegration

from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN

logger = setup_logger()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)

if SENTRY_DSN:
    sentry_sdk.init(
        dsn=SENTRY_DSN,
        integrations=[CeleryIntegration()],
        traces_sample_rate=0.5,
    )
    logger.info("Sentry initialized")
else:
    logger.debug("Sentry DSN not provided, skipping Sentry initialization")


celery_app = Celery(__name__)
celery_app.config_from_object(
    "danswer.background.celery.celeryconfig"
)  # Load configuration from 'celeryconfig.py'

@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    tenant_id: str | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    pass


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict[str, Any] | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """We handle this signal in order to remove completed tasks
    from their respective tasksets. This allows us to track the progress of document set
    and user group syncs.

    This function runs after any task completes (both success and failure)
    Note that this signal does not fire on a task that failed to complete and is going
    to be retried.

    This also does not fire if a worker with acks_late=False crashes (which all of our
    long running workers are)
    """
    if not task:
        return

    # Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
    if not kwargs:
        logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
        tenant_id = None
    else:
        tenant_id = kwargs.get("tenant_id")

    task_logger.debug(
        f"Task {task.name} (ID: {task_id}) completed with state: {state} "
        f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
    )

    if state not in READY_STATES:
        return

    if not task_id:
        return

    r = get_redis_client(tenant_id=tenant_id)

    if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
        r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
        return

    if task_id.startswith(RedisDocumentSet.PREFIX):
        document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
        if document_set_id is not None:
            rds = RedisDocumentSet(int(document_set_id))
            r.srem(rds.taskset_key, task_id)
        return

    if task_id.startswith(RedisUserGroup.PREFIX):
        usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
        if usergroup_id is not None:
            rug = RedisUserGroup(int(usergroup_id))
            r.srem(rug.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorDeletion.PREFIX):
        cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            rcd = RedisConnectorDeletion(int(cc_pair_id))
            r.srem(rcd.taskset_key, task_id)
        return

    if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
        cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            rcp = RedisConnectorPruning(int(cc_pair_id))
            r.srem(rcp.taskset_key, task_id)
        return

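The prefix checks above work because object ids are recoverable from the task id string. A hedged sketch of one such convention, assuming ids are minted as prefix_objectid_uuid; the actual scheme lives in danswer/background/celery/celery_redis.py and may differ:

from uuid import uuid4

PREFIX = "documentset"  # illustrative prefix, not necessarily the real one

def make_task_id(object_id: int) -> str:
    # e.g. "documentset_42_8f14e45f-..." - the object id sits in the middle segment
    return f"{PREFIX}_{object_id}_{uuid4()}"

def get_id_from_task_id(task_id: str) -> str | None:
    # uuid4 strings contain no underscores, so a well-formed id splits into 3 parts
    parts = task_id.split("_")
    return parts[1] if len(parts) == 3 else None

task_id = make_task_id(42)
assert task_id.startswith(PREFIX)
assert get_id_from_task_id(task_id) == "42"
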
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """The first signal sent on celery worker startup"""
    multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn


@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
    SqlEngine.init_engine(pool_size=2, max_overflow=0)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    # decide some initial startup settings based on the celery worker's hostname
    # (set at the command line)
    hostname = sender.hostname
    if hostname.startswith("light"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
        SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
    elif hostname.startswith("heavy"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)
    elif hostname.startswith("indexing"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)
        tenant_ids = get_all_tenant_ids()

        for tenant_id in tenant_ids:
            # TODO: why is this necessary for the indexer to do?
            with get_session_with_tenant(tenant_id) as db_session:
                check_index_swap(db_session=db_session)
                search_settings = get_current_search_settings(db_session)

                # So that the first time users aren't surprised by really slow speed of first
                # batch of documents indexed
                if search_settings.provider_type is None:
                    logger.notice(
                        "Running a first inference to warm up embedding model"
                    )
                    embedding_model = EmbeddingModel.from_db_model(
                        search_settings=search_settings,
                        server_host=INDEXING_MODEL_SERVER_HOST,
                        server_port=MODEL_SERVER_PORT,
                    )

                    warm_up_bi_encoder(
                        embedding_model=embedding_model,
                    )
                    logger.notice("First inference complete.")
    else:
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)

    if not hasattr(sender, "primary_worker_locks"):
        sender.primary_worker_locks = {}

    tenant_ids = get_all_tenant_ids()

    if not celery_is_worker_primary(sender):
        logger.info("Running as a secondary celery worker.")
        for tenant_id in tenant_ids:
            r = get_redis_client(tenant_id=tenant_id)
            WAIT_INTERVAL = 5
            WAIT_LIMIT = 60
            time_start = time.monotonic()
            logger.notice("Redis: Readiness check starting.")
            while True:
                # Log all the locks in Redis
                all_locks = r.keys("*")
                logger.notice(f"Current Redis locks: {all_locks}")
                if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
                    break
                time_elapsed = time.monotonic() - time_start
                logger.info(
                    f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
                )
                if time_elapsed > WAIT_LIMIT:
                    msg = (
                        "Redis: Readiness check did not succeed within the timeout "
                        f"({WAIT_LIMIT} seconds). Exiting..."
                    )
                    logger.error(msg)
                    raise WorkerShutdown(msg)
                time.sleep(WAIT_INTERVAL)
            logger.info("Wait for primary worker completed successfully. Continuing...")
        return  # Exit the function for secondary workers

    for tenant_id in tenant_ids:
        r = get_redis_client(tenant_id=tenant_id)

        WAIT_INTERVAL = 5
        WAIT_LIMIT = 60

        time_start = time.monotonic()
        logger.info("Running as the primary celery worker.")

        # This is singleton work that should be done on startup exactly once
        # by the primary worker
        r = get_redis_client(tenant_id=tenant_id)

        # For the moment, we're assuming that we are the only primary worker
        # that should be running.
        # TODO: maybe check for or clean up another zombie primary worker if we detect it
        r.delete(DanswerRedisLocks.PRIMARY_WORKER)

        # this process wide lock is taken to help other workers start up in order.
        # it is planned to use this lock to enforce singleton behavior on the primary
        # worker, since the primary worker does redis cleanup on startup, but this isn't
        # implemented yet.
        lock = r.lock(
            DanswerRedisLocks.PRIMARY_WORKER,
            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
        )

        logger.info("Primary worker lock: Acquire starting.")
        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
        if acquired:
            logger.info("Primary worker lock: Acquire succeeded.")
        else:
            logger.error("Primary worker lock: Acquire failed!")
            raise WorkerShutdown("Primary worker lock could not be acquired!")

        sender.primary_worker_locks[tenant_id] = lock

        # As currently designed, when this worker starts as "primary", we reinitialize redis
        # to a clean state (for our purposes, anyway)
        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

        r.delete(RedisConnectorCredentialPair.get_taskset_key())
        r.delete(RedisConnectorCredentialPair.get_fence_key())

        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
            r.delete(key)

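The gating above is the heart of this branch: secondary workers block until the primary's Redis lock key exists, and the primary holds a TTL'd lock it must keep refreshing. A minimal standalone sketch of the same pattern, assuming a local Redis; LOCK_NAME and LOCK_TIMEOUT are illustrative stand-ins for DanswerRedisLocks.PRIMARY_WORKER and CELERY_PRIMARY_WORKER_LOCK_TIMEOUT:

import time

import redis

LOCK_NAME = "primary_worker"  # illustrative; the real name comes from DanswerRedisLocks
LOCK_TIMEOUT = 120  # seconds; stands in for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT

r = redis.Redis(host="localhost", port=6379, db=0)

def start_primary() -> redis.lock.Lock:
    # take the lock with a TTL so a crashed primary can't wedge the cluster forever
    lock = r.lock(LOCK_NAME, timeout=LOCK_TIMEOUT)
    if not lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2):
        raise RuntimeError("another primary appears to be running")
    return lock

def wait_for_primary(limit: float = 60.0, interval: float = 5.0) -> None:
    # secondary workers block until the primary's lock key exists
    start = time.monotonic()
    while not r.exists(LOCK_NAME):
        if time.monotonic() - start > limit:
            raise TimeoutError("primary worker never came up")
        time.sleep(interval)

def heartbeat(lock: redis.lock.Lock) -> None:
    # call this on a timer well under LOCK_TIMEOUT; resets the TTL if still owned
    if lock.owned():
        lock.reacquire()
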
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
#     """This only runs inside child processes when the worker is in pool=prefork mode.
#     This may be technically unnecessary since we're finding prefork pools to be
#     unstable and currently aren't planning on using them."""
#     logger.info("worker_process_init signal received.")
#     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
#     SqlEngine.init_engine(pool_size=5, max_overflow=0)

#     # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
#     SqlEngine.get_engine().dispose(close=False)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    task_logger.info("worker_ready signal received.")


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    if not celery_is_worker_primary(sender):
        return

    if not hasattr(sender, "primary_worker_locks"):
        return

    logger.info("Releasing primary worker lock.")
    for tenant_id, lock in sender.primary_worker_locks.items():
        logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
        if lock.owned():
            lock.release()
    sender.primary_worker_locks = {}

class CeleryTaskPlainFormatter(PlainFormatter):
    def format(self, record: logging.LogRecord) -> str:
        task = current_task
        if task and task.request:
            record.__dict__.update(task_id=task.request.id, task_name=task.name)
            record.msg = f"[{task.name}({task.request.id})] {record.msg}"

        return super().format(record)


class CeleryTaskColoredFormatter(ColoredFormatter):
    def format(self, record: logging.LogRecord) -> str:
        task = current_task
        if task and task.request:
            record.__dict__.update(task_id=task.request.id, task_name=task.name)
            record.msg = f"[{task.name}({task.request.id})] {record.msg}"

        return super().format(record)

@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    # TODO: could unhardcode format and colorize and accept these as options from
    # celery's config

    # reformats the root logger
    root_logger = logging.getLogger()

    root_handler = logging.StreamHandler()  # Set up a handler for the root logger
    root_formatter = ColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    root_handler.setFormatter(root_formatter)
    root_logger.addHandler(root_handler)  # Apply the handler to the root logger

    if logfile:
        root_file_handler = logging.FileHandler(logfile)
        root_file_formatter = PlainFormatter(
            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
            datefmt="%m/%d/%Y %I:%M:%S %p",
        )
        root_file_handler.setFormatter(root_file_formatter)
        root_logger.addHandler(root_file_handler)

    root_logger.setLevel(loglevel)

    # reformats celery's task logger
    task_formatter = CeleryTaskColoredFormatter(
        "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    task_handler = logging.StreamHandler()  # Set up a handler for the task logger
    task_handler.setFormatter(task_formatter)
    task_logger.addHandler(task_handler)  # Apply the handler to the task logger

    if logfile:
        task_file_handler = logging.FileHandler(logfile)
        task_file_formatter = CeleryTaskPlainFormatter(
            "%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
            datefmt="%m/%d/%Y %I:%M:%S %p",
        )
        task_file_handler.setFormatter(task_file_formatter)
        task_logger.addHandler(task_file_handler)

    task_logger.setLevel(loglevel)
    task_logger.propagate = False

class HubPeriodicTask(bootsteps.StartStopStep):
    """Regularly reacquires the primary worker locks for all tenants outside of the task queue.
    Use the task_logger in this class to avoid double logging.

    This cannot be done inside a regular beat task because it must run on schedule and
    a queue of existing work would starve the task from running.
    """

    # Requires the Hub component
    requires = {"celery.worker.components:Hub"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        super().__init__(worker, **kwargs)
        self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
        self.task_tref = None

    def start(self, worker: Any) -> None:
        if not celery_is_worker_primary(worker):
            return

        # Access the worker's event loop (hub)
        hub = worker.consumer.controller.hub

        # Schedule the periodic task
        self.task_tref = hub.call_repeatedly(
            self.interval, self.run_periodic_task, worker
        )
        task_logger.info("Scheduled periodic task with hub.")

    def run_periodic_task(self, worker: Any) -> None:
        try:
            if not celery_is_worker_primary(worker):
                return

            if not hasattr(worker, "primary_worker_locks"):
                return

            # Retrieve all tenant IDs
            tenant_ids = get_all_tenant_ids()

            for tenant_id in tenant_ids:
                lock = worker.primary_worker_locks.get(tenant_id)
                if not lock:
                    continue  # Skip if no lock for this tenant

                r = get_redis_client(tenant_id=tenant_id)

                if lock.owned():
                    task_logger.debug(
                        f"Reacquiring primary worker lock for tenant {tenant_id}."
                    )
                    lock.reacquire()
                else:
                    task_logger.warning(
                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
                        "Reasons could be worker restart or lock expiration."
                    )
                    lock = r.lock(
                        DanswerRedisLocks.PRIMARY_WORKER,
                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                    )

                    task_logger.info(
                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
                    )
                    acquired = lock.acquire(
                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                    )
                    if acquired:
                        task_logger.info(
                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
                        )
                        worker.primary_worker_locks[tenant_id] = lock
                    else:
                        task_logger.error(
                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
                        )
                        raise TimeoutError(
                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
                        )

        except Exception as e:
            task_logger.error(f"Error in periodic task: {e}")

    def stop(self, worker: Any) -> None:
        # Cancel the scheduled task when the worker stops
        if self.task_tref:
            self.task_tref.cancel()
            task_logger.info("Canceled periodic task with hub.")


celery_app.steps["worker"].add(HubPeriodicTask)

celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.connector_deletion",
        "danswer.background.celery.tasks.indexing",
        "danswer.background.celery.tasks.periodic",
        "danswer.background.celery.tasks.pruning",
        "danswer.background.celery.tasks.shared",
        "danswer.background.celery.tasks.vespa",
    ]
)

#####
# Celery Beat (Periodic Tasks) Settings
#####

tenant_ids = get_all_tenant_ids()

tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-connector-deletion",
        "task": "check_for_connector_deletion_task",
        "schedule": timedelta(seconds=60),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-prune",
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "kombu-message-cleanup",
        "task": "kombu_message_cleanup_task",
        "schedule": timedelta(seconds=3600),
        "options": {"priority": DanswerCeleryPriority.LOWEST},
    },
    {
        "name": "monitor-vespa-sync",
        "task": "monitor_vespa_sync",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
]

# Build the celery beat schedule dynamically
beat_schedule = {}

for tenant_id in tenant_ids:
    for task in tasks_to_schedule:
        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
        beat_schedule[task_name] = {
            "task": task["task"],
            "schedule": task["schedule"],
            "options": task["options"],
            "kwargs": {"tenant_id": tenant_id},  # Must pass tenant_id as an argument
        }

# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)

# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

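Concretely, the loop above fans each entry in tasks_to_schedule out across every tenant, so beat produces one schedule entry per (task, tenant) pair. A sketch of what beat_schedule ends up holding for two hypothetical tenants (the tenant names and the priority value are illustrative, not from any real deployment):

from datetime import timedelta

# hypothetical tenant ids; real ones are Postgres schema names with a tenant prefix
tenant_ids = ["tenant_a", "tenant_b"]

beat_schedule = {
    "check-for-indexing-tenant_a": {
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": 0},  # stand-in for DanswerCeleryPriority.HIGH
        "kwargs": {"tenant_id": "tenant_a"},
    },
    # ... plus "check-for-indexing-tenant_b", "check-for-prune-tenant_a",
    # "monitor-vespa-sync-tenant_b", and so on for every task/tenant pair
}
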
@@ -10,7 +10,7 @@ from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session

from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
@@ -313,8 +313,6 @@ class RedisConnectorDeletion(RedisObjectHelper):
        lock: redis.lock.Lock,
        tenant_id: str | None,
    ) -> int | None:
        """Returns None if the cc_pair doesn't exist.
        Otherwise, returns an int with the number of generated tasks."""
        last_lock_time = time.monotonic()

        async_results = []
@@ -467,8 +465,14 @@ class RedisConnectorPruning(RedisObjectHelper):

        return len(async_results)

    def is_pruning(self, redis_client: Redis) -> bool:
    def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
        """A single example of a helper method being refactored into the redis helper"""
        cc_pair = get_connector_credential_pair_from_id(
            cc_pair_id=int(self._id), db_session=db_session
        )
        if not cc_pair:
            raise ValueError(f"cc_pair_id {self._id} does not exist.")

        if redis_client.exists(self.fence_key):
            return True

@@ -534,36 +538,6 @@ class RedisConnectorIndexing(RedisObjectHelper):
    ) -> int | None:
        return None

    def is_indexing(self, redis_client: Redis) -> bool:
        """A single example of a helper method being refactored into the redis helper"""
        if redis_client.exists(self.fence_key):
            return True

        return False


class RedisConnectorStop(RedisObjectHelper):
    """Used to signal any running tasks for a connector to stop. We should refactor
    connector related redis helpers into a single class.
    """

    PREFIX = "connectorstop"
    FENCE_PREFIX = PREFIX + "_fence"  # a fence for the entire indexing process
    TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of prune task ids

    def __init__(self, id: int) -> None:
        super().__init__(str(id))

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: redis.lock.Lock | None,
        tenant_id: str | None,
    ) -> int | None:
        return None


def celery_get_queue_length(queue: str, r: Redis) -> int:
    """This is a redis specific way to get the length of a celery queue.
@@ -1,8 +1,9 @@
"""Factory stub for running celery worker / celery beat."""
"""Entry point for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable


set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
    "danswer.background.celery.apps.primary", "celery_app"
celery_app = fetch_versioned_implementation(
    "danswer.background.celery.celery_app", "celery_app"
)
@@ -1,21 +1,25 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any

from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
@@ -71,15 +75,13 @@ def get_deletion_attempt_snapshot(
    )


def document_batch_to_ids(
    doc_batch: list[Document],
) -> set[str]:
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
    return {doc.id for doc in doc_batch}


def extract_ids_from_runnable_connector(
    runnable_connector: BaseConnector,
    callback: RunIndexingCallbackInterface | None = None,
    progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
    """
    If the PruneConnector hasn't been implemented for the given connector, just pull
@@ -89,13 +91,10 @@ def extract_ids_from_runnable_connector(
    """
    all_connector_doc_ids: set[str] = set()

    if isinstance(runnable_connector, SlimConnector):
        for metadata_batch in runnable_connector.retrieve_all_slim_documents():
            all_connector_doc_ids.update({doc.id for doc in metadata_batch})

    doc_batch_generator = None

    if isinstance(runnable_connector, LoadConnector):
    if isinstance(runnable_connector, IdConnector):
        all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
    elif isinstance(runnable_connector, LoadConnector):
        doc_batch_generator = runnable_connector.load_from_state()
    elif isinstance(runnable_connector, PollConnector):
        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
@@ -104,17 +103,16 @@ def extract_ids_from_runnable_connector(
    else:
        raise RuntimeError("Pruning job could not find a valid runnable_connector.")

    doc_batch_processing_func = document_batch_to_ids
    if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
        doc_batch_processing_func = rate_limit_builder(
            max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
        )(document_batch_to_ids)
    for doc_batch in doc_batch_generator:
        if callback:
            if callback.should_stop():
                raise RuntimeError("Stop signal received")
            callback.progress(len(doc_batch))
        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
    if doc_batch_generator:
        doc_batch_processing_func = document_batch_to_ids
        if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
            doc_batch_processing_func = rate_limit_builder(
                max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
            )(document_batch_to_ids)
        for doc_batch in doc_batch_generator:
            if progress_callback:
                progress_callback(len(doc_batch))
            all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

    return all_connector_doc_ids

@@ -135,10 +133,33 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
def celery_is_worker_primary(worker: Any) -> bool:
    """There are multiple approaches that could be taken to determine if a celery worker
    is 'primary', as defined by us. But the way we do it is to check the hostname set
    for the celery worker, which can be done on the
    for the celery worker, which can be done either in celeryconfig.py or on the
    command line with '--hostname'."""
    hostname = worker.hostname
    if hostname.startswith("primary"):
        return True

    return False


def get_all_tenant_ids() -> list[str] | list[None]:
    if not MULTI_TENANT:
        return [None]
    with get_session_with_tenant(tenant_id="public") as session:
        result = session.execute(
            text(
                """
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
            )
        )
        tenant_ids = [row[0] for row in result]

    valid_tenants = [
        tenant
        for tenant in tenant_ids
        if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
    ]

    return valid_tenants

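A quick sanity check of the schema filter used by get_all_tenant_ids above, with made-up schema names; the real value of TENANT_ID_PREFIX lives in danswer.configs.constants, and "tenant_" here is only an assumption for illustration:

TENANT_ID_PREFIX = "tenant_"  # assumed value, for illustration only

# schemata as they might come back from information_schema.schemata
schemata = ["tenant_3f9a", "tenant_77b1", "alembic_version_store", "extensions"]

# non-tenant schemas are dropped; only prefixed schema names survive
valid = [s for s in schemata if s.startswith(TENANT_ID_PREFIX)]
assert valid == ["tenant_3f9a", "tenant_77b1"]
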
@@ -31,10 +31,21 @@ if REDIS_SSL:
    if REDIS_SSL_CA_CERTS:
        SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"

# region Broker settings
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"

result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"

# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4

# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False

broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT

@@ -49,7 +60,6 @@ broker_transport_options = {
    "socket_keepalive": True,
    "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# endregion

# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
@@ -63,19 +73,10 @@ redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True

# region Task result backend settings
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
result_expires = CELERY_RESULT_EXPIRES  # 86400 seconds is the default
# endregion

# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False

# region Notes on serialization performance
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result

@@ -101,4 +102,3 @@ result_expires = CELERY_RESULT_EXPIRES  # 86400 seconds is the default
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]
# endregion
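For reference, the broker_url template above expands as in this sketch; all values are illustrative placeholders, echoing the file's own example comment rather than any real deployment:

REDIS_SCHEME = "redis"
CELERY_PASSWORD_PART = ":s3cret@"   # empty string when no Redis password is set
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_DB_NUMBER_CELERY = 15
SSL_QUERY_PARAMS = ""               # populated only when REDIS_SSL is enabled

broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
assert broker_url == "redis://:s3cret@localhost:6379/15"
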
@@ -1,14 +0,0 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
import danswer.background.celery.configs.base as shared_config

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default
@@ -1,20 +0,0 @@
import danswer.background.celery.configs.base as shared_config

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = 4
worker_pool = "threads"
worker_prefetch_multiplier = 1
@@ -1,21 +0,0 @@
import danswer.background.celery.configs.base as shared_config
from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
@@ -1,22 +0,0 @@
import danswer.background.celery.configs.base as shared_config
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
@@ -1,20 +0,0 @@
import danswer.background.celery.configs.base as shared_config

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = 4
worker_pool = "threads"
worker_prefetch_multiplier = 1
@@ -1,45 +1,29 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
"""Raised to the caller to indicate dependent tasks are running that would interfere
|
||||
with connector deletion."""
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
@@ -52,30 +36,12 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
r.delete(rcs.fence_key)

try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat, tenant_id
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -88,8 +54,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N


def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -98,87 +63,51 @@ def try_generate_document_cc_pair_cleanup_tasks(
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.

Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
still running. In our case, the caller reacts by setting a stop signal in Redis to
exit those tasks as quickly as possible.
"""

lock_beat.reacquire()

rcd = RedisConnectorDeletion(cc_pair_id)
rcd = RedisConnectorDeletion(cc_pair.id)

# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None

# we need to load the state of the object inside the fence
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None

if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None

# set a basic fence to start
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)

# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
r.set(rcd.fence_key, fence_value.model_dump_json())

try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)

rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)

# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)

# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
r.delete(rcd.fence_key)
tasks_generated = rcd.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0

task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0

# set this only after all tasks have been added
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)

# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

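The hunk above moves connector deletion to a fenced taskset: a partial fence goes up before task generation, dependent work (indexing or pruning) blocks it with TaskDependencyError, and the fence is finalized only once every task is queued. Below is a minimal sketch of that two-phase fence, assuming a local Redis; the key names and the generate_tasks hook are hypothetical stand-ins for what RedisConnectorDeletion provides.

```python
import json
from datetime import datetime, timezone
from typing import Callable

import redis

# Hypothetical key names for illustration; the real keys come from
# RedisConnectorDeletion in danswer's celery_redis module.
FENCE_KEY = "connectordeletion_fence_42"

r = redis.Redis()

def guarded_generate(generate_tasks: Callable[[], int]) -> int | None:
    """Raise a fence before generating tasks; tear it down on any failure."""
    if r.exists(FENCE_KEY):
        return None  # a previous run is still being monitored

    # publish a partial fence first so monitors know setup is underway
    r.set(FENCE_KEY, json.dumps({"num_tasks": None,
                                 "submitted": datetime.now(timezone.utc).isoformat()}))
    try:
        num_tasks = generate_tasks()
    except Exception:
        r.delete(FENCE_KEY)  # never leave a fence up for work that was not queued
        raise

    # finalize the fence only after every task has been enqueued
    r.set(FENCE_KEY, json.dumps({"num_tasks": num_tasks,
                                 "submitted": datetime.now(timezone.utc).isoformat()}))
    return num_tasks
```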
@@ -5,26 +5,18 @@ from time import sleep
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -48,48 +40,18 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT

logger = setup_logger()


class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client

def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False

def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
self.redis_client.incrby(self.generator_progress_key, amount)


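RunIndexingCallback above funnels two concerns through Redis: a stop signal polled between batches and an atomic progress counter whose updates also refresh the lock's TTL. A small usage sketch against a local Redis follows; the key and lock names are made up for illustration.

```python
import redis

r = redis.Redis()

STOP_KEY = "connectorstop_fence_42"    # hypothetical stop-signal key
PROGRESS_KEY = "indexing_progress_42"  # hypothetical progress counter key

lock = r.lock("indexing_lock_example", timeout=300)  # hypothetical lock name
assert lock.acquire(blocking=False)

# Another process requests a graceful stop simply by setting the stop key;
# the worker polls it between batches (RunIndexingCallback.should_stop above).
r.set(STOP_KEY, 1)
should_stop = bool(r.exists(STOP_KEY))

# Progress is an atomic counter; reacquire() refreshes the lock TTL so a
# long-running generator is not evicted mid-run (RunIndexingCallback.progress).
lock.reacquire()
r.incrby(PROGRESS_KEY, 8)
print(should_stop, r.get(PROGRESS_KEY))

lock.release()
```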
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0

r = get_redis_client(tenant_id=tenant_id)
@@ -102,54 +64,31 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")

with get_session_with_tenant(tenant_id=tenant_id) as db_session:
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)

cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]

# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)

cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)

for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]

# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)

for cc_pair in cc_pairs:
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair_id, search_settings_instance.id
cc_pair.id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session
)
if not cc_pair:
continue

last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
@@ -165,7 +104,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
@@ -275,7 +213,6 @@ def _should_index(


def try_creating_indexing_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
@@ -311,10 +248,6 @@ def try_creating_indexing_task(
return None

# skip indexing if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None

db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -325,19 +258,7 @@ def try_creating_indexing_task(

custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"

# set a basic fence to start
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
r.set(rci.fence_key, fence_value.model_dump_json())

# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
# create the index attempt ... just for tracking purposes
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
@@ -358,14 +279,17 @@ def try_creating_indexing_task(
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
return None

# now fill out the fence with the rest of the data
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
r.delete(rci.fence_key)
task_logger.exception("Unexpected exception")
return None
finally:
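try_creating_indexing_task orders its side effects deliberately: fence first, then the index attempt row, then the Celery task, then the filled-in fence, so cleanup code elsewhere can treat an attempt without a fence as orphaned. Below is a rough sketch of the send_task call with a predictable, prefix-scoped task id, assuming a Redis broker; the task id prefix, kwargs, and priority value are illustrative guesses, not the project's exact signature.

```python
from uuid import uuid4

from celery import Celery

app = Celery("example", broker="redis://localhost:6379/0")  # assumed broker URL

# a predictable, prefix-scoped task id lets monitors find the task later
custom_task_id = f"connectorindexing_example_{uuid4()}"  # hypothetical prefix

result = app.send_task(
    "connector_indexing_proxy_task",  # task name as referenced in the hunk above
    kwargs={
        "index_attempt_id": 1,  # illustrative values only
        "cc_pair_id": 42,
        "search_settings_id": 7,
        "tenant_id": None,
    },
    task_id=custom_task_id,
    priority=5,  # stand-in for DanswerCeleryPriority.MEDIUM
)
print(result.id == custom_task_id)  # True: the id we chose is the AsyncResult id
```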
@@ -448,56 +372,8 @@ def connector_indexing_task(

r = get_redis_client(tenant_id=tenant_id)

rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={rcd.fence_key}"
)

rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={rcs.fence_key}"
)

rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)

while True:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
task_logger.info(
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
raise RuntimeError(f"Fence not found: fence={rci.fence_key}")

try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise

if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue

task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break

lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
@@ -507,20 +383,17 @@ def connector_indexing_task(
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None

fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())

try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
f"Index attempt not found: index_attempt_id={index_attempt_id}"
)

cc_pair = get_connector_credential_pair_from_id(
@@ -529,31 +402,31 @@ def connector_indexing_task(
)

if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")

if not cc_pair.connector:
raise ValueError(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
f"Connector not found: connector_id={cc_pair.connector_id}"
)

if not cc_pair.credential:
raise ValueError(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
f"Credential not found: credential_id={cc_pair.credential_id}"
)

rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)

# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)

run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
progress_callback=redis_increment_callback,
)

# get back the total number of indexed docs and return it
@@ -566,10 +439,9 @@ def connector_indexing_task(

r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
if attempt:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
mark_attempt_failed(attempt, db_session, failure_reason=str(e))

r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)

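connector_indexing_task above spins on the fence until the producer has filled in index_attempt_id and celery_task_id. Here is a condensed sketch of that wait loop using a pydantic model mirroring RedisConnectorIndexingFenceData's shape; the max_tries bound is an added safeguard, not part of the original loop.

```python
from datetime import datetime
from time import sleep

import redis
from pydantic import BaseModel

class FenceData(BaseModel):  # mirrors RedisConnectorIndexingFenceData's fields
    index_attempt_id: int | None
    started: datetime | None
    submitted: datetime
    celery_task_id: str | None

def wait_for_fence(r: redis.Redis, fence_key: str, max_tries: int = 30) -> FenceData:
    """Poll until the producer has fully populated the fence payload."""
    for _ in range(max_tries):
        raw = r.get(fence_key)
        if raw is None:
            raise RuntimeError(f"Fence not found: fence={fence_key}")
        data = FenceData.model_validate_json(raw)
        if data.index_attempt_id is not None and data.celery_task_id is not None:
            return data  # fully populated: safe to proceed
        sleep(1)  # producer is still setting up
    raise TimeoutError(f"Fence never completed: fence={fence_key}")
```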
@@ -11,7 +11,7 @@ from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant

@@ -3,19 +3,15 @@ from datetime import timedelta
from datetime import timezone
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -27,7 +23,6 @@ from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
@@ -42,9 +37,8 @@ logger = setup_logger()
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)

lock_beat = r.lock(
@@ -57,24 +51,15 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not lock_beat.acquire(blocking=False):
return

cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)

for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
continue

for cc_pair in cc_pairs:
lock_beat.reacquire()
if not is_pruning_due(cc_pair, db_session, r):
continue

tasks_created = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
@@ -133,7 +118,6 @@ def is_pruning_due(


def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
@@ -171,10 +155,6 @@ def try_creating_prune_generator_task(
return None

# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None

db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -216,14 +196,9 @@ def try_creating_prune_generator_task(
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_pruning_generator_task(
self: Task,
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str | None,
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -241,7 +216,7 @@ def connector_pruning_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
)
return None

@@ -259,22 +234,22 @@ def connector_pruning_generator_task(
)
return

# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)

runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.SLIM_RETRIEVAL,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)

rcs = RedisConnectorStop(cc_pair_id)

callback = RunIndexingCallback(
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
runnable_connector, redis_increment_callback
)

# a list of docs in our local index
@@ -292,7 +267,7 @@ def connector_pruning_generator_task(

task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"cc_pair_id={cc_pair.id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
@@ -300,24 +275,22 @@ def connector_pruning_generator_task(
rcp.documents_to_prune = set(doc_ids_to_remove)

task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(
self.app, db_session, r, None, tenant_id
celery_app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None

task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)

r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")

r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)

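The pruning generator's docstring above describes what boils down to a set difference: any document id held locally that the source no longer reports is slated for removal. A tiny worked example with made-up ids:

```python
def compute_prune_set(local_doc_ids: set[str], source_doc_ids: set[str]) -> set[str]:
    """Docs we hold locally that the source no longer returns must be pruned."""
    return local_doc_ids - source_doc_ids

# toy example with hypothetical document ids
local = {"doc-1", "doc-2", "doc-3"}
source = {"doc-1", "doc-3"}
assert compute_prune_set(local, source) == {"doc-2"}
```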
@@ -1,8 +0,0 @@
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime
@@ -1,10 +0,0 @@
from datetime import datetime

from pydantic import BaseModel


class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
@@ -1,14 +1,16 @@
from datetime import datetime

from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel

from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_modified
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_session_with_tenant
@@ -17,7 +19,12 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier

DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3

class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str


@shared_task(
@@ -25,7 +32,7 @@ DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task,
@@ -49,7 +56,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"tenant_id={tenant_id} document_id={document_id}")
task_logger.info(f"document_id={document_id}")

try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -115,8 +122,6 @@ def document_by_cc_pair_cleanup_task(
else:
pass

db_session.commit()

task_logger.info(
f"tenant_id={tenant_id} "
f"document_id={document_id} "
@@ -124,27 +129,16 @@ def document_by_cc_pair_cleanup_task(
f"refcount={count} "
f"chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
return False
except Exception as e:
task_logger.exception("Unexpected exception")

if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
f"tenant_id={tenant_id} document_id={document_id}"
)
with get_session_with_tenant(tenant_id):
mark_document_as_modified(document_id, db_session)
return False
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)

return True

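The retry arithmetic above, 2 ** (retries + 4), yields 16, 32 and 64 second delays before the task gives up; on the new side, the final failure marks the document dirty so stale-document reconciliation can repair it out of band. A sketch of that shape, assuming a Redis broker; flaky_cleanup and mark_dirty are hypothetical stand-ins for the real cleanup work and mark_document_as_modified.

```python
from celery import Celery, Task

app = Celery("example", broker="redis://localhost:6379/0")  # assumed broker

MAX_RETRIES = 3

def flaky_cleanup(document_id: str) -> None:
    """Stand-in for the real per-document cleanup work, which may fail."""

def mark_dirty(document_id: str) -> None:
    """Stand-in for mark_document_as_modified's out-of-band reconciliation hook."""

@app.task(bind=True, max_retries=MAX_RETRIES)
def cleanup_task(self: Task, document_id: str) -> bool:
    try:
        flaky_cleanup(document_id)
    except Exception as e:
        if self.request.retries < MAX_RETRIES:
            # 2^4, 2^5, 2^6 seconds -> 16, 32, 64
            countdown = 2 ** (self.request.retries + 4)
            raise self.retry(exc=e, countdown=countdown)
        # last attempt: hand the document off for stale-document reconciliation
        mark_dirty(document_id)
        return False
    return True
```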
@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast

import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
@@ -15,7 +14,8 @@ from redis import Redis
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
@@ -23,12 +23,7 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
@@ -59,6 +54,7 @@ from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -80,9 +76,8 @@ logger = setup_logger()
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""

@@ -99,53 +94,35 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
return

with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, db_session, r, lock_beat, tenant_id
)
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)

# region document set scan
document_set_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)

for document_set, _ in document_set_info:
document_set_ids.append(document_set.id)

for document_set_id in document_set_ids:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
document_set, db_session, r, lock_beat, tenant_id
)
# endregion

# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass
else:
usergroup_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)

user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)

for usergroup in user_groups:
usergroup_ids.append(usergroup.id)

for usergroup_id in usergroup_ids:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
usergroup, db_session, r, lock_beat, tenant_id
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass

except SoftTimeLimitExceeded:
task_logger.info(
@@ -159,11 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:


def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -214,8 +187,7 @@ def try_generate_stale_document_sync_tasks(


def try_generate_document_set_sync_tasks(
celery_app: Celery,
document_set_id: int,
document_set: DocumentSet,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -223,7 +195,7 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()

rds = RedisDocumentSet(document_set_id)
rds = RedisDocumentSet(document_set.id)

# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
@@ -231,10 +203,7 @@ def try_generate_document_set_sync_tasks(

# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
document_set = get_document_set_by_id(db_session, document_set_id)
if not document_set:
return None

db_session.refresh(document_set)
if document_set.is_up_to_date:
return None

@@ -269,8 +238,7 @@ def try_generate_document_set_sync_tasks(


def try_generate_user_group_sync_tasks(
celery_app: Celery,
usergroup_id: int,
usergroup: UserGroup,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -278,21 +246,14 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()

rug = RedisUserGroup(usergroup_id)
rug = RedisUserGroup(usergroup.id)

# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None

# race condition with the monitor/cleanup function if we use a cached result!
fetch_user_group = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_group"
)

usergroup = fetch_user_group(db_session, usergroup_id)
if not usergroup:
return None

db_session.refresh(usergroup)
if usergroup.is_up_to_date:
return None

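All the try_generate_* helpers above share one progress protocol: a Redis set holds one member per outstanding task, the fence stores the initial count, and monitors infer progress from the set's cardinality. A minimal sketch of that protocol with hypothetical key names (the real ones come from RedisDocumentSet and friends):

```python
import redis

r = redis.Redis()

FENCE_KEY = "documentset_fence_9"      # hypothetical fence key
TASKSET_KEY = "documentset_taskset_9"  # hypothetical taskset key

def publish(doc_ids: list[str]) -> None:
    """Producer: record one taskset member per queued task, then raise the fence."""
    for doc_id in doc_ids:
        r.sadd(TASKSET_KEY, doc_id)
    r.set(FENCE_KEY, len(doc_ids))  # the initial task count

def progress() -> tuple[int, int]:
    """Monitor: remaining work is just the taskset cardinality (SCARD)."""
    initial = int(r.get(FENCE_KEY) or 0)
    remaining = r.scard(TASKSET_KEY)  # workers SREM their id as they finish
    return initial, remaining
```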
@@ -371,7 +332,7 @@ def monitor_document_set_taskset(

count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set={document_set_id} "
f"Document set sync progress: document_set_id={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
@@ -386,12 +347,12 @@ def monitor_document_set_taskset(
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set: document_set={document_set_id}"
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set: document_set={document_set_id}"
f"Successfully synced document set with ID: '{document_set_id}'!"
)

r.delete(rds.taskset_key)
@@ -411,29 +372,19 @@ def monitor_connector_deletion_taskset(

rcd = RedisConnectorDeletion(cc_pair_id)

# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return

try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise

# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
task_logger.error("The value is not an integer.")
return

count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
@@ -496,7 +447,7 @@ def monitor_connector_deletion_taskset(
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
@@ -506,17 +457,17 @@ def monitor_connector_deletion_taskset(
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
f"Connector deletion exceptioned: "
f"Failed to run connector_deletion. "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e

task_logger.info(
f"Connector deletion succeeded: "
f"Successfully deleted cc_pair: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
f"docs_deleted={initial_count}"
)

r.delete(rcd.taskset_key)
@@ -626,12 +577,7 @@ def monitor_ccpair_indexing_taskset(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)

if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return

# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state

@@ -733,9 +679,36 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"pruning={n_pruning}"
)

# do some cleanup before clearing fences
# check the db for any outstanding index attempts
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)

# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
@@ -753,42 +726,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)

lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)

# uncomment for debugging if needed

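monitor_vespa_sync walks each fence family with scan_iter and calls lock_beat.reacquire() inside the loop so a slow pass cannot outlive the lock's TTL. A compact sketch of that pattern; the lock name is illustrative.

```python
from typing import Callable

import redis

r = redis.Redis()
lock_beat = r.lock("monitor_lock_example", timeout=60)  # hypothetical lock name

def monitor_fences(prefix: bytes, handle: Callable[[bytes], None]) -> None:
    """Walk matching fence keys without blocking Redis, keeping the beat lock alive."""
    if not lock_beat.acquire(blocking=False):
        return  # another beat instance is already monitoring
    try:
        # scan_iter pages through the keyspace cursor-style, unlike KEYS
        for key_bytes in r.scan_iter(prefix + b"*"):
            lock_beat.reacquire()  # extend the TTL so slow handlers keep the lock
            handle(key_bytes)
    finally:
        lock_beat.release()
```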
@@ -1,8 +0,0 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)
@@ -1,17 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery

from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
from danswer.background.celery.apps.heavy import celery_app

return celery_app


app = get_app()
@@ -1,17 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery

from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
from danswer.background.celery.apps.indexing import celery_app

return celery_app


app = get_app()
@@ -1,17 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery

from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
from danswer.background.celery.apps.light import celery_app

return celery_app


app = get_app()
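These factory stubs all hinge on resolving an implementation at runtime: EE builds ship an override module, MIT builds fall back to the bundled one. A rough sketch of that lookup under a hypothetical module path; the real fetch_versioned_implementation_with_fallback also consults the EE flag set by set_is_ee_based_on_env_variable().

```python
import importlib
from typing import Any, Callable

def fetch_impl_with_fallback(module: str, attr: str, fallback: Callable) -> Any:
    """Resolve an override if the module exists, else return the given fallback.

    A simplified stand-in for fetch_versioned_implementation_with_fallback.
    """
    try:
        return getattr(importlib.import_module(module), attr)
    except (ModuleNotFoundError, AttributeError):
        return fallback

def noop(*args: Any, **kwargs: Any) -> None:
    return None  # safe default when no EE implementation is shipped

# "some.ee.module" is a hypothetical module path for illustration
handler = fetch_impl_with_fallback("some.ee.module", "handler", noop)
handler()
```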
@@ -1,7 +1,6 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -42,19 +41,6 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5


class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""

@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""

@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""


def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -106,7 +92,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: RunIndexingCallbackInterface | None = None,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -220,11 +206,6 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")

# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
if (
(
@@ -282,8 +263,8 @@ def _run_indexing(
# be inaccurate
db_session.commit()

if callback:
callback.progress(len(doc_batch))
if progress_callback:
progress_callback(len(doc_batch))

# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -413,7 +394,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: RunIndexingCallbackInterface | None = None,
progress_callback: Callable[[int], None] | None = None,
) -> None:
try:
if is_ee:
@@ -436,7 +417,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

_run_indexing(db_session, attempt, tenant_id, callback)
_run_indexing(db_session, attempt, tenant_id, progress_callback)

logger.info(
f"Indexing finished for tenant {tenant_id}: "

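The run_indexing.py hunks trade between two callback shapes: an ABC that bundles should_stop and progress on one object, and a bare Callable[[int], None] when only progress reporting is needed. A side-by-side sketch of the two styles with toy data; everything here is illustrative, not the project's exact code.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable

# Interface style: one object can carry several hooks (stop + progress).
class IndexingCallback(ABC):
    @abstractmethod
    def should_stop(self) -> bool: ...

    @abstractmethod
    def progress(self, amount: int) -> None: ...

# Callable style: a single remaining hook needs no class at all.
ProgressCallback = Callable[[int], None]

def run_batches(batches: list[list[str]],
                progress_callback: ProgressCallback | None) -> None:
    for batch in batches:
        # ... index the batch here ...
        if progress_callback:
            progress_callback(len(batch))

run_batches([["a", "b"], ["c"]], progress_callback=print)  # prints 2 then 1
```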
backend/danswer/background/update.py (new executable file, 494 lines)
@@ -0,0 +1,494 @@
# TODO(rkuo): delete after background indexing via celery is fully vetted
# import logging
# import time
# from datetime import datetime
# import dask
# from dask.distributed import Client
# from dask.distributed import Future
# from distributed import LocalCluster
# from sqlalchemy import text
# from sqlalchemy.exc import ProgrammingError
# from sqlalchemy.orm import Session
# from danswer.background.indexing.dask_utils import ResourceLogger
# from danswer.background.indexing.job_client import SimpleJob
# from danswer.background.indexing.job_client import SimpleJobClient
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
# from danswer.configs.app_configs import MULTI_TENANT
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
# from danswer.configs.constants import DocumentSource
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
# from danswer.configs.constants import TENANT_ID_PREFIX
# from danswer.db.connector import fetch_connectors
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
# from danswer.db.engine import get_db_current_time
# from danswer.db.engine import get_session_with_tenant
# from danswer.db.engine import get_sqlalchemy_engine
# from danswer.db.engine import SqlEngine
# from danswer.db.index_attempt import create_index_attempt
# from danswer.db.index_attempt import get_index_attempt
# from danswer.db.index_attempt import get_inprogress_index_attempts
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
# from danswer.db.index_attempt import get_not_started_index_attempts
# from danswer.db.index_attempt import mark_attempt_failed
# from danswer.db.models import ConnectorCredentialPair
# from danswer.db.models import IndexAttempt
# from danswer.db.models import IndexingStatus
# from danswer.db.models import IndexModelStatus
# from danswer.db.models import SearchSettings
# from danswer.db.search_settings import get_current_search_settings
# from danswer.db.search_settings import get_secondary_search_settings
# from danswer.db.swap_index import check_index_swap
# from danswer.document_index.vespa.index import VespaIndex
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
# from danswer.utils.logger import setup_logger
# from danswer.utils.variable_functionality import global_version
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
# from shared_configs.configs import LOG_LEVEL
# logger = setup_logger()
# # If the indexing dies, it's most likely due to resource constraints,
# # restarting just delays the eventual failure, not useful to the user
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
# _UNEXPECTED_STATE_FAILURE_REASON = (
#     "Stopped mid run, likely due to the background process being killed"
# )
# def _should_create_new_indexing(
#     cc_pair: ConnectorCredentialPair,
#     last_index: IndexAttempt | None,
#     search_settings_instance: SearchSettings,
#     secondary_index_building: bool,
#     db_session: Session,
# ) -> bool:
#     connector = cc_pair.connector
#     # don't kick off indexing for `NOT_APPLICABLE` sources
#     if connector.source == DocumentSource.NOT_APPLICABLE:
#         return False
#     # User can still manually create single indexing attempts via the UI for the
#     # currently in use index
#     if DISABLE_INDEX_UPDATE_ON_SWAP:
#         if (
#             search_settings_instance.status == IndexModelStatus.PRESENT
#             and secondary_index_building
#         ):
#             return False
#     # When switching over models, always index at least once
#     if search_settings_instance.status == IndexModelStatus.FUTURE:
#         if last_index:
#             # No new index if the last index attempt succeeded
#             # Once is enough. The model will never be able to swap otherwise.
#             if last_index.status == IndexingStatus.SUCCESS:
#                 return False
#             # No new index if the last index attempt is waiting to start
#             if last_index.status == IndexingStatus.NOT_STARTED:
#                 return False
#             # No new index if the last index attempt is running
#             if last_index.status == IndexingStatus.IN_PROGRESS:
#                 return False
#         else:
#             if (
#                 connector.id == 0 or connector.source == DocumentSource.INGESTION_API
#             ):  # Ingestion API
#                 return False
#         return True
#     # If the connector is paused or is the ingestion API, don't index
#     # NOTE: during an embedding model switch over, the following logic
#     # is bypassed by the above check for a future model
#     if (
#         not cc_pair.status.is_active()
#         or connector.id == 0
#         or connector.source == DocumentSource.INGESTION_API
#     ):
#         return False
#     if not last_index:
#         return True
#     if connector.refresh_freq is None:
#         return False
#     # Only one scheduled/ongoing job per connector at a time
#     # this prevents cases where
#     # (1) the "latest" index_attempt is scheduled so we show
#     # that in the UI despite another index_attempt being in-progress
#     # (2) multiple scheduled index_attempts at a time
#     if (
#         last_index.status == IndexingStatus.NOT_STARTED
#         or last_index.status == IndexingStatus.IN_PROGRESS
#     ):
#         return False
#     current_db_time = get_db_current_time(db_session)
#     time_since_index = current_db_time - last_index.time_updated
#     return time_since_index.total_seconds() >= connector.refresh_freq
# def _mark_run_failed(
#     db_session: Session, index_attempt: IndexAttempt, failure_reason: str
# ) -> None:
#     """Marks the `index_attempt` row as failed + updates the `
#     connector_credential_pair` to reflect that the run failed"""
#     logger.warning(
#         f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
#         f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
#     )
#     mark_attempt_failed(
#         index_attempt=index_attempt,
#         db_session=db_session,
#         failure_reason=failure_reason,
#     )
# """Main funcs"""
# def create_indexing_jobs(
#     existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
# ) -> None:
#     """Creates new indexing jobs for each connector / credential pair which is:
#     1. Enabled
#     2. `refresh_frequency` time has passed since the last indexing run for this pair
#     3. There is not already an ongoing indexing attempt for this pair
#     """
#     with get_session_with_tenant(tenant_id) as db_session:
#         ongoing: set[tuple[int | None, int]] = set()
#         for attempt_id in existing_jobs:
#             attempt = get_index_attempt(
#                 db_session=db_session, index_attempt_id=attempt_id
#             )
#             if attempt is None:
#                 logger.error(
#                     f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
#                     "indexing jobs"
#                 )
#                 continue
#             ongoing.add(
#                 (
#                     attempt.connector_credential_pair_id,
#                     attempt.search_settings_id,
#                 )
#             )
#         # Get the primary search settings
#         primary_search_settings = get_current_search_settings(db_session)
#         search_settings = [primary_search_settings]
#         # Check for secondary search settings
#         secondary_search_settings = get_secondary_search_settings(db_session)
#         if secondary_search_settings is not None:
#             # If secondary settings exist, add them to the list
#             search_settings.append(secondary_search_settings)
#         all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
#         for cc_pair in all_connector_credential_pairs:
#             for search_settings_instance in search_settings:
#                 # Check if there is an ongoing indexing attempt for this connector credential pair
#                 if (cc_pair.id, search_settings_instance.id) in ongoing:
#                     continue
#                 last_attempt = get_last_attempt_for_cc_pair(
#                     cc_pair.id, search_settings_instance.id, db_session
#                 )
#                 if not _should_create_new_indexing(
#                     cc_pair=cc_pair,
#                     last_index=last_attempt,
#                     search_settings_instance=search_settings_instance,
#                     secondary_index_building=len(search_settings) > 1,
#                     db_session=db_session,
#                 ):
#                     continue
#                 create_index_attempt(
#                     cc_pair.id, search_settings_instance.id, db_session
#                 )
# def cleanup_indexing_jobs(
#     existing_jobs: dict[int, Future | SimpleJob],
#     tenant_id: str | None,
#     timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
# ) -> dict[int, Future | SimpleJob]:
#     existing_jobs_copy = existing_jobs.copy()
#     # clean up completed jobs
#     with get_session_with_tenant(tenant_id) as db_session:
#         for attempt_id, job in existing_jobs.items():
#             index_attempt = get_index_attempt(
#                 db_session=db_session, index_attempt_id=attempt_id
#             )
#             # do nothing for ongoing jobs that haven't been stopped
#             if not job.done():
#                 if not index_attempt:
#                     continue
#                 if not index_attempt.is_finished():
#                     continue
#             if job.status == "error":
#                 logger.error(job.exception())
#             job.release()
#             del existing_jobs_copy[attempt_id]
#             if not index_attempt:
#                 logger.error(
#                     f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
#                     "up indexing jobs"
#                 )
#                 continue
#             if (
#                 index_attempt.status == IndexingStatus.IN_PROGRESS
#                 or job.status == "error"
#             ):
#                 _mark_run_failed(
#                     db_session=db_session,
#                     index_attempt=index_attempt,
#                     failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
#                 )
#     # clean up in-progress jobs that were never completed
#     try:
#         connectors = fetch_connectors(db_session)
#         for connector in connectors:
#             in_progress_indexing_attempts = get_inprogress_index_attempts(
#                 connector.id, db_session
#             )
#             for index_attempt in in_progress_indexing_attempts:
|
||||
# if index_attempt.id in existing_jobs:
|
||||
# # If index attempt is canceled, stop the run
|
||||
# if index_attempt.status == IndexingStatus.FAILED:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# # check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# # on the fact that the `time_updated` field is constantly updated every
|
||||
# # batch of documents indexed
|
||||
# current_db_time = get_db_current_time(db_session=db_session)
|
||||
# time_since_update = current_db_time - index_attempt.time_updated
|
||||
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
# "The run will be re-attempted at next scheduled indexing time.",
|
||||
# )
|
||||
# else:
|
||||
# # If job isn't known, simply mark it as failed
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# except ProgrammingError:
|
||||
# logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
# return existing_jobs_copy
|
||||
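For reference, the frozen-run detection in the commented-out cleanup_indexing_jobs above boils down to a staleness comparison against the database clock. A minimal, self-contained sketch (the standalone function and its name are hypothetical; the three-hour default mirrors the failure message above):

import datetime


def is_attempt_frozen(
    time_updated: datetime.datetime,
    current_db_time: datetime.datetime,
    timeout_hours: int = 3,
) -> bool:
    # An attempt counts as frozen when `time_updated` has not moved within
    # the timeout window; indexing bumps that column after every batch
    time_since_update = current_db_time - time_updated
    return time_since_update.total_seconds() > 60 * 60 * timeout_hours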
# def kickoff_indexing_jobs(
#     existing_jobs: dict[int, Future | SimpleJob],
#     client: Client | SimpleJobClient,
#     secondary_client: Client | SimpleJobClient,
#     tenant_id: str | None,
# ) -> dict[int, Future | SimpleJob]:
#     existing_jobs_copy = existing_jobs.copy()
#
#     current_session = get_session_with_tenant(tenant_id)
#
#     # Don't include jobs waiting in the Dask queue that just haven't started running
#     # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
#     with current_session as db_session:
#         # get_not_started_index_attempts orders its returned results from oldest to newest
#         # we must process attempts in a FIFO manner to prevent connector starvation
#         new_indexing_attempts = [
#             (attempt, attempt.search_settings)
#             for attempt in get_not_started_index_attempts(db_session)
#             if attempt.id not in existing_jobs
#         ]
#
#     logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
#
#     if not new_indexing_attempts:
#         return existing_jobs
#
#     indexing_attempt_count = 0
#     primary_client_full = False
#     secondary_client_full = False
#     for attempt, search_settings in new_indexing_attempts:
#         if primary_client_full and secondary_client_full:
#             break
#
#         use_secondary_index = (
#             search_settings.status == IndexModelStatus.FUTURE
#             if search_settings is not None
#             else False
#         )
#
#         if attempt.connector_credential_pair.connector is None:
#             logger.warning(
#                 f"Skipping index attempt as Connector has been deleted: {attempt}"
#             )
#             with current_session as db_session:
#                 mark_attempt_failed(
#                     attempt, db_session, failure_reason="Connector is null"
#                 )
#             continue
#
#         if attempt.connector_credential_pair.credential is None:
#             logger.warning(
#                 f"Skipping index attempt as Credential has been deleted: {attempt}"
#             )
#             with current_session as db_session:
#                 mark_attempt_failed(
#                     attempt, db_session, failure_reason="Credential is null"
#                 )
#             continue
#
#         if not use_secondary_index:
#             if not primary_client_full:
#                 run = client.submit(
#                     run_indexing_entrypoint,
#                     attempt.id,
#                     tenant_id,
#                     attempt.connector_credential_pair_id,
#                     global_version.is_ee_version(),
#                     pure=False,
#                 )
#                 if not run:
#                     primary_client_full = True
#         else:
#             if not secondary_client_full:
#                 run = secondary_client.submit(
#                     run_indexing_entrypoint,
#                     attempt.id,
#                     tenant_id,
#                     attempt.connector_credential_pair_id,
#                     global_version.is_ee_version(),
#                     pure=False,
#                 )
#                 if not run:
#                     secondary_client_full = True
#
#         if run:
#             if indexing_attempt_count == 0:
#                 logger.info(
#                     f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
#                 )
#
#             indexing_attempt_count += 1
#             secondary_str = " (secondary index)" if use_secondary_index else ""
#             logger.info(
#                 f"Indexing dispatched{secondary_str}: "
#                 f"attempt_id={attempt.id} "
#                 f"connector='{attempt.connector_credential_pair.connector.name}' "
#                 f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
#                 f"credentials='{attempt.connector_credential_pair.credential_id}'"
#             )
#             existing_jobs_copy[attempt.id] = run
#
#     if indexing_attempt_count > 0:
#         logger.info(
#             f"Indexing dispatch results: "
#             f"initial_pending={len(new_indexing_attempts)} "
#             f"started={indexing_attempt_count} "
#             f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
#         )
#
#     return existing_jobs_copy
#
#
# def get_all_tenant_ids() -> list[str] | list[None]:
#     if not MULTI_TENANT:
#         return [None]
#     with get_session_with_tenant(tenant_id="public") as session:
#         result = session.execute(
#             text(
#                 """
#                 SELECT schema_name
#                 FROM information_schema.schemata
#                 WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
#             )
#         )
#         tenant_ids = [row[0] for row in result]
#
#     valid_tenants = [
#         tenant
#         for tenant in tenant_ids
#         if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
#     ]
#
#     return valid_tenants
#
#
# def update_loop(
#     delay: int = 10,
#     num_workers: int = NUM_INDEXING_WORKERS,
#     num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
# ) -> None:
#     if not MULTI_TENANT:
#         # We can use this function as we are certain only the public schema exists
#         # (explicitly for the non-`MULTI_TENANT` case)
#         engine = get_sqlalchemy_engine()
#         with Session(engine) as db_session:
#             check_index_swap(db_session=db_session)
#             search_settings = get_current_search_settings(db_session)
#
#             # So that the first time users aren't surprised by really slow speed of first
#             # batch of documents indexed
#             if search_settings.provider_type is None:
#                 logger.notice("Running a first inference to warm up embedding model")
#                 embedding_model = EmbeddingModel.from_db_model(
#                     search_settings=search_settings,
#                     server_host=INDEXING_MODEL_SERVER_HOST,
#                     server_port=INDEXING_MODEL_SERVER_PORT,
#                 )
#                 warm_up_bi_encoder(
#                     embedding_model=embedding_model,
#                 )
#                 logger.notice("First inference complete.")
#
#     client_primary: Client | SimpleJobClient
#     client_secondary: Client | SimpleJobClient
#     if DASK_JOB_CLIENT_ENABLED:
#         cluster_primary = LocalCluster(
#             n_workers=num_workers,
#             threads_per_worker=1,
#             silence_logs=logging.ERROR,
#         )
#         cluster_secondary = LocalCluster(
#             n_workers=num_secondary_workers,
#             threads_per_worker=1,
#             silence_logs=logging.ERROR,
#         )
#         client_primary = Client(cluster_primary)
#         client_secondary = Client(cluster_secondary)
#         if LOG_LEVEL.lower() == "debug":
#             client_primary.register_worker_plugin(ResourceLogger())
#     else:
#         client_primary = SimpleJobClient(n_workers=num_workers)
#         client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
#
#     existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
#
#     logger.notice("Startup complete. Waiting for indexing jobs...")
#     while True:
#         start = time.time()
#         start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
#         logger.debug(f"Running update, current UTC time: {start_time_utc}")
#
#         if existing_jobs:
#             logger.debug(
#                 "Found existing indexing jobs: "
#                 f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
#             )
#
#         try:
#             tenants = get_all_tenant_ids()
#
#             for tenant_id in tenants:
#                 try:
#                     logger.debug(
#                         f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
#                     )
#                     with get_session_with_tenant(tenant_id) as db_session:
#                         index_to_expire = check_index_swap(db_session=db_session)
#
#                         if index_to_expire and tenant_id and MULTI_TENANT:
#                             VespaIndex.delete_entries_by_tenant_id(
#                                 tenant_id=tenant_id,
#                                 index_name=index_to_expire.index_name,
#                             )
#
#                         if not MULTI_TENANT:
#                             search_settings = get_current_search_settings(db_session)
#                             if search_settings.provider_type is None:
#                                 logger.notice(
#                                     "Running a first inference to warm up embedding model"
#                                 )
#                                 embedding_model = EmbeddingModel.from_db_model(
#                                     search_settings=search_settings,
#                                     server_host=INDEXING_MODEL_SERVER_HOST,
#                                     server_port=INDEXING_MODEL_SERVER_PORT,
#                                 )
#                                 warm_up_bi_encoder(embedding_model=embedding_model)
#                                 logger.notice("First inference complete.")
#
#                     tenant_jobs = existing_jobs.get(tenant_id, {})
#
#                     tenant_jobs = cleanup_indexing_jobs(
#                         existing_jobs=tenant_jobs, tenant_id=tenant_id
#                     )
#                     create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
#                     tenant_jobs = kickoff_indexing_jobs(
#                         existing_jobs=tenant_jobs,
#                         client=client_primary,
#                         secondary_client=client_secondary,
#                         tenant_id=tenant_id,
#                     )
#
#                     existing_jobs[tenant_id] = tenant_jobs
#                 except Exception as e:
#                     logger.exception(
#                         f"Failed to process tenant {tenant_id or 'default'}: {e}"
#                     )
#         except Exception as e:
#             logger.exception(f"Failed to run update due to {e}")
#
#         sleep_time = delay - (time.time() - start)
#         if sleep_time > 0:
#             time.sleep(sleep_time)
#
#
# def update__main() -> None:
#     set_is_ee_based_on_env_variable()
#
#     # initialize the Postgres connection pool
#     SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
#
#     logger.notice("Starting indexing service")
#     update_loop()
#
#
# if __name__ == "__main__":
#     update__main()
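The scheduling gate in the commented-out `_should_create_new_indexing` above ultimately reduces to comparing elapsed time against the connector's `refresh_freq`. A minimal sketch of just that final check (the function name is hypothetical; the None-means-manual-only behavior matches the code above):

import datetime


def is_due_for_reindex(
    last_updated: datetime.datetime,
    now: datetime.datetime,
    refresh_freq_seconds: int | None,
) -> bool:
    # No refresh frequency configured means runs are manual-only
    if refresh_freq_seconds is None:
        return False
    return (now - last_updated).total_seconds() >= refresh_freq_seconds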
@@ -672,7 +672,6 @@ def stream_chat_message_objects(
                all_docs_useful=selected_db_search_docs is not None
            ),
            document_pruning_config=document_pruning_config,
            structured_response_format=new_msg_req.structured_response_format,
        ),
        prompt_config=prompt_config,
        llm=(

@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED

# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"

# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -134,6 +131,7 @@ try:
except ValueError:
    INDEX_BATCH_SIZE = 16


# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@@ -142,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
    os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"

POSTGRES_API_SERVER_POOL_SIZE = int(
@@ -200,41 +198,6 @@ try:
except ValueError:
    CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT

CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
try:
    CELERY_WORKER_LIGHT_CONCURRENCY = int(
        os.environ.get(
            "CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
        )
    )
except ValueError:
    CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT

CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
try:
    CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
        os.environ.get(
            "CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
            CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
        )
    )
except ValueError:
    CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
        CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
    )

CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
try:
    env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
    if not env_value:
        env_value = os.environ.get("NUM_INDEXING_WORKERS")

    if not env_value:
        env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
    CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
except ValueError:
    CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT

#####
# Connector Configs
#####
@@ -290,6 +253,12 @@ CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
    os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
)

# Save pages labels as Danswer metadata tags
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
    os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)

# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)

@@ -46,6 +46,7 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"

# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -70,7 +71,6 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"

CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120

@@ -13,8 +13,8 @@ Connectors come in 3 different flows:
    documents via a connector's API or loads the documents from some sort of a dump file.
- Poll connector:
  - Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
    changes and additions since the last round of polling. This connector helps keep the document index up to date
    without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
    additions and changes since the last round of polling. This connector helps keep the document index up to date
    without needing to fetch/embed/index every document which would generally be too slow to do frequently on large sets of
    documents.
- Event Based connectors:
  - Connectors that listen to events and update documents accordingly (a sketch of the Poll flow follows below).
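To make the Poll flow described above concrete, here is a minimal sketch of a poll-style connector. The `poll_source` method name and exact signature are assumptions based on the `PollConnector` and `SecondsSinceUnixEpoch` interface imports elsewhere in this diff, not something this hunk itself shows:

from datetime import datetime, timezone

SecondsSinceUnixEpoch = float


class ExamplePollConnector:
    # Illustrative only: yields batches of documents changed in the window
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ):
        window_start = datetime.fromtimestamp(start, tz=timezone.utc)
        window_end = datetime.fromtimestamp(end, tz=timezone.utc)
        # A real connector would query its source API for this window
        yield [{"id": "doc-1", "updated_between": (window_start, window_end)}]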
@@ -15,6 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -23,7 +24,6 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder


logger = setup_logger()

@@ -10,6 +10,7 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -18,7 +19,6 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.retry_wrapper import retry_builder


CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"

32 backend/danswer/connectors/confluence/confluence_utils.py Normal file
@@ -0,0 +1,32 @@
import bs4


def build_confluence_document_id(base_url: str, content_url: str) -> str:
    """For confluence, the document id is the page url for a page based document
    or the attachment download url for an attachment based document

    Args:
        base_url (str): The base url of the Confluence instance
        content_url (str): The url of the page or attachment download url

    Returns:
        str: The document id
    """
    return f"{base_url}{content_url}"


def get_used_attachments(text: str) -> list[str]:
    """Parse a Confluence html page to generate a list of
    attachments currently in use

    Args:
        text (str): The page content

    Returns:
        list[str]: List of filenames currently in use by the page text
    """
    files_in_used = []
    soup = bs4.BeautifulSoup(text, "html.parser")
    for attachment in soup.findAll("ri:attachment"):
        files_in_used.append(attachment.attrs["ri:filename"])
    return files_in_used
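A quick usage sketch of `get_used_attachments` from the new file above, against a hypothetical Confluence storage-format snippet (the `ri:attachment` / `ri:filename` markup is exactly what the function scans for):

import bs4

page_html = '<p>See <ri:attachment ri:filename="report.pdf"></ri:attachment></p>'
soup = bs4.BeautifulSoup(page_html, "html.parser")
# Same extraction the function performs internally
filenames = [a.attrs["ri:filename"] for a in soup.findAll("ri:attachment")]
print(filenames)  # ['report.pdf']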
File diff suppressed because it is too large
@@ -1,226 +0,0 @@
import math
import time
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote

from atlassian import Confluence  # type:ignore
from requests import HTTPError

from danswer.utils.logger import setup_logger

logger = setup_logger()


F = TypeVar("F", bound=Callable[..., Any])


RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()


class ConfluenceRateLimitError(Exception):
    pass


def _handle_http_error(e: HTTPError, attempt: int) -> int:
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    # Check if the response or headers are None to avoid potential AttributeError
    if e.response is None or e.response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    if (
        e.response.status_code != 429
        and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
    ):
        raise e

    retry_after = None

    retry_after_header = e.response.headers.get("Retry-After")
    if retry_after_header is not None:
        try:
            retry_after = int(retry_after_header)
            if retry_after > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
                )
                retry_after = MAX_DELAY
            if retry_after < MIN_DELAY:
                retry_after = MIN_DELAY
        except ValueError:
            pass

    if retry_after is not None:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
        )
        delay = retry_after
    else:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)

    delay_until = math.ceil(time.monotonic() + delay)
    return delay_until


# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 3600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                delay_until = _handle_http_error(e, attempt)
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError as e:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise e

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)


_DEFAULT_PAGINATION_LIMIT = 100


class OnyxConfluence(Confluence):
    """
    This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
    This is necessary because the default Confluence class does not properly support cql expansions.
    All methods are automatically wrapped with handle_confluence_rate_limit.
    """

    def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
        super(OnyxConfluence, self).__init__(url, *args, **kwargs)
        self._wrap_methods()

    def _wrap_methods(self) -> None:
        """
        For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
        wrap it with handle_confluence_rate_limit.
        """
        for attr_name in dir(self):
            if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
                setattr(
                    self,
                    attr_name,
                    handle_confluence_rate_limit(getattr(self, attr_name)),
                )

    def _paginate_url(
        self, url_suffix: str, limit: int | None = None
    ) -> Iterator[list[dict[str, Any]]]:
        """
        This will paginate through the top level query.
        """
        if not limit:
            limit = _DEFAULT_PAGINATION_LIMIT

        connection_char = "&" if "?" in url_suffix else "?"
        url_suffix += f"{connection_char}limit={limit}"

        while url_suffix:
            try:
                next_response = self.get(url_suffix)
            except Exception as e:
                logger.exception("Error in danswer_cql: \n")
                raise e
            yield next_response.get("results", [])
            url_suffix = next_response.get("_links", {}).get("next")

    def paginated_groups_retrieval(
        self,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        return self._paginate_url("rest/api/group", limit)

    def paginated_group_members_retrieval(
        self,
        group_name: str,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        group_name = quote(group_name)
        return self._paginate_url(f"rest/api/group/{group_name}/member", limit)

    def paginated_cql_user_retrieval(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        expand_string = f"&expand={expand}" if expand else ""
        return self._paginate_url(
            f"rest/api/search/user?cql={cql}{expand_string}", limit
        )

    def paginated_cql_page_retrieval(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        expand_string = f"&expand={expand}" if expand else ""
        return self._paginate_url(
            f"rest/api/content/search?cql={cql}{expand_string}", limit
        )

    def cql_paginate_all_expansions(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[list[dict[str, Any]]]:
        """
        This function will paginate through the top level query first, then
        paginate through all of the expansions.
        The limit only applies to the top level query.
        All expansion paginations use default pagination limit (defined by Atlassian).
        """

        def _traverse_and_update(data: dict | list) -> None:
            if isinstance(data, dict):
                next_url = data.get("_links", {}).get("next")
                if next_url and "results" in data:
                    data["results"].extend(self._paginate_url(next_url))

                for value in data.values():
                    _traverse_and_update(value)
            elif isinstance(data, list):
                for item in data:
                    _traverse_and_update(item)

        for results in self.paginated_cql_page_retrieval(cql, expand, limit):
            _traverse_and_update(results)
            yield results
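The pagination loop in the removed `_paginate_url` above follows the relative `_links.next` cursor the Confluence REST API returns until it is absent. A generic standalone sketch of the same pattern (auth and error handling omitted; the function name is hypothetical):

from collections.abc import Iterator
from typing import Any

import requests


def follow_next_links(base_url: str, url_suffix: str) -> Iterator[list[dict[str, Any]]]:
    while url_suffix:
        payload = requests.get(f"{base_url}/{url_suffix.lstrip('/')}").json()
        yield payload.get("results", [])
        # Relative URL of the next page, or None when exhausted
        url_suffix = (payload.get("_links") or {}).get("next")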
219 backend/danswer/connectors/confluence/rate_limit_handler.py Normal file
@@ -0,0 +1,219 @@
import math
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar

from requests import HTTPError

from danswer.utils.logger import setup_logger

logger = setup_logger()


F = TypeVar("F", bound=Callable[..., Any])


RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()


class ConfluenceRateLimitError(Exception):
    pass


# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
#     def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
#         max_retries = 5
#         starting_delay = 5
#         backoff = 2
#
#         # max_delay is used when the server doesn't hand back "Retry-After"
#         # and we have to decide the retry delay ourselves
#         max_delay = 30  # Atlassian uses max_delay = 30 in their examples
#
#         # max_retry_after is used when we do get a "Retry-After" header
#         max_retry_after = 300  # should we really cap the maximum retry delay?
#
#         NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
#
#         # for testing purposes, rate limiting is written to fall back to a simpler
#         # rate limiting approach when redis is not available
#         r = get_redis_client(tenant_id=tenant_id)
#
#         for attempt in range(max_retries):
#             try:
#                 # if multiple connectors are waiting for the next attempt, there could be an issue
#                 # where many connectors are "released" onto the server at the same time.
#                 # That's not ideal ... but coming up with a mechanism for queueing
#                 # all of these connectors is a bigger problem that we want to take on
#                 # right now
#                 try:
#                     next_attempt = r.get(NEXT_RETRY_KEY)
#                     if next_attempt is None:
#                         next_attempt = 0
#                     else:
#                         next_attempt = int(cast(int, next_attempt))
#
#                     # TODO: all connectors need to be interruptible moving forward
#                     while time.monotonic() < next_attempt:
#                         time.sleep(1)
#                 except ConnectionError:
#                     pass
#
#                 return confluence_call(*args, **kwargs)
#             except HTTPError as e:
#                 # Check if the response or headers are None to avoid potential AttributeError
#                 if e.response is None or e.response.headers is None:
#                     logger.warning("HTTPError with `None` as response or as headers")
#                     raise e
#
#                 retry_after_header = e.response.headers.get("Retry-After")
#                 if (
#                     e.response.status_code == 429
#                     or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
#                 ):
#                     retry_after = None
#                     if retry_after_header is not None:
#                         try:
#                             retry_after = int(retry_after_header)
#                         except ValueError:
#                             pass
#
#                     if retry_after is not None:
#                         if retry_after > max_retry_after:
#                             logger.warning(
#                                 f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
#                             )
#                             retry_after = max_delay
#
#                         logger.warning(
#                             f"Rate limit hit. Retrying after {retry_after} seconds..."
#                         )
#                         try:
#                             r.set(
#                                 NEXT_RETRY_KEY,
#                                 math.ceil(time.monotonic() + retry_after),
#                             )
#                         except ConnectionError:
#                             pass
#                     else:
#                         logger.warning(
#                             "Rate limit hit. Retrying with exponential backoff..."
#                         )
#                         delay = min(starting_delay * (backoff**attempt), max_delay)
#                         delay_until = math.ceil(time.monotonic() + delay)
#
#                         try:
#                             r.set(NEXT_RETRY_KEY, delay_until)
#                         except ConnectionError:
#                             while time.monotonic() < delay_until:
#                                 time.sleep(1)
#                 else:
#                     # re-raise, let caller handle
#                     raise
#             except AttributeError as e:
#                 # Some error within the Confluence library, unclear why it fails.
#                 # Users reported it to be intermittent, so just retry
#                 logger.warning(f"Confluence Internal Error, retrying... {e}")
#                 delay = min(starting_delay * (backoff**attempt), max_delay)
#                 delay_until = math.ceil(time.monotonic() + delay)
#                 try:
#                     r.set(NEXT_RETRY_KEY, delay_until)
#                 except ConnectionError:
#                     while time.monotonic() < delay_until:
#                         time.sleep(1)
#
#                 if attempt == max_retries - 1:
#                     raise e
#
#     return cast(F, wrapped_call)


def _handle_http_error(e: HTTPError, attempt: int) -> int:
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    # Check if the response or headers are None to avoid potential AttributeError
    if e.response is None or e.response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    if (
        e.response.status_code != 429
        and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
    ):
        raise e

    retry_after = None

    retry_after_header = e.response.headers.get("Retry-After")
    if retry_after_header is not None:
        try:
            retry_after = int(retry_after_header)
            if retry_after > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
                )
                retry_after = MAX_DELAY
            if retry_after < MIN_DELAY:
                retry_after = MIN_DELAY
        except ValueError:
            pass

    if retry_after is not None:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
        )
        delay = retry_after
    else:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)

    delay_until = math.ceil(time.monotonic() + delay)
    return delay_until


# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 3600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                delay_until = _handle_http_error(e, attempt)
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError as e:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise e

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
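Because `make_confluence_call_handle_rate_limit` wraps an arbitrary callable, it can be applied to individual client methods. A hedged usage sketch (the URL and credentials are placeholders; `get_page_by_id` is a standard atlassian-python-api Confluence method):

from atlassian import Confluence  # type: ignore

confluence = Confluence(
    url="https://example.atlassian.net/wiki",  # placeholder instance
    username="me@example.com",
    password="api-token",
)
safe_get_page = make_confluence_call_handle_rate_limit(confluence.get_page_by_id)
page = safe_get_page("12345", expand="body.storage")  # retried on 429s per the handler above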
@@ -1,214 +0,0 @@
import io
from datetime import datetime
from datetime import timezone
from typing import Any

import bs4

from danswer.configs.app_configs import (
    CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from danswer.connectors.confluence.onyx_confluence import (
    OnyxConfluence,
)
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.html_utils import format_document_soup
from danswer.utils.logger import setup_logger

logger = setup_logger()


_USER_EMAIL_CACHE: dict[str, str | None] = {}


def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
) -> str | None:
    global _USER_EMAIL_CACHE
    if _USER_EMAIL_CACHE.get(user_name) is None:
        try:
            response = confluence_client.get_mobile_parameters(user_name)
            email = response.get("email")
        except Exception:
            email = None
        _USER_EMAIL_CACHE[user_name] = email
    return _USER_EMAIL_CACHE[user_name]


_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}


def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
    """Get Confluence Display Name based on the account-id or userkey value

    Args:
        user_id (str): The user id (i.e: the account-id or userkey)
        confluence_client (Confluence): The Confluence Client

    Returns:
        str: The User Display Name. 'Unknown User' if the user is deactivated or not found
    """
    global _USER_ID_TO_DISPLAY_NAME_CACHE
    if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
        try:
            result = confluence_client.get_user_details_by_userkey(user_id)
            found_display_name = result.get("displayName")
        except Exception:
            found_display_name = None

        if not found_display_name:
            try:
                result = confluence_client.get_user_details_by_accountid(user_id)
                found_display_name = result.get("displayName")
            except Exception:
                found_display_name = None

        _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name

    return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND


def extract_text_from_confluence_html(
    confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
) -> str:
    """Parse a Confluence html page and replace the 'user Id' by the real
    User Display Name

    Args:
        confluence_object (dict): The confluence object as a dict
        confluence_client (Confluence): Confluence client

    Returns:
        str: loaded and formatted Confluence page
    """
    body = confluence_object["body"]
    object_html = body.get("storage", body.get("view", {})).get("value")

    soup = bs4.BeautifulSoup(object_html, "html.parser")
    for user in soup.findAll("ri:user"):
        user_id = (
            user.attrs["ri:account-id"]
            if "ri:account-id" in user.attrs
            else user.get("ri:userkey")
        )
        if not user_id:
            logger.warning(
                "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
            )
            continue
        # Include @ sign for tagging, more clear for LLM
        user.replaceWith("@" + _get_user(confluence_client, user_id))
    return format_document_soup(soup)


def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
    if attachment["metadata"]["mediaType"] in [
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/svg+xml",
        "video/mp4",
        "video/quicktime",
    ]:
        return None

    download_link = confluence_client.url + attachment["_links"]["download"]

    attachment_size = attachment["extensions"]["fileSize"]
    if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
        logger.warning(
            f"Skipping {download_link} due to size. "
            f"size={attachment_size} "
            f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
        )
        return None

    logger.info(f"_attachment_to_content - _session.get: link={download_link}")
    response = confluence_client._session.get(download_link)
    if response.status_code != 200:
        logger.warning(
            f"Failed to fetch {download_link} with invalid status code {response.status_code}"
        )
        return None

    extracted_text = extract_file_text(
        io.BytesIO(response.content),
        file_name=attachment["title"],
        break_on_unprocessable=False,
    )
    if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
        logger.warning(
            f"Skipping {download_link} due to char count. "
            f"char count={len(extracted_text)} "
            f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
        )
        return None

    return extracted_text


def build_confluence_document_id(base_url: str, content_url: str) -> str:
    """For confluence, the document id is the page url for a page based document
    or the attachment download url for an attachment based document

    Args:
        base_url (str): The base url of the Confluence instance
        content_url (str): The url of the page or attachment download url

    Returns:
        str: The document id
    """
    return f"{base_url}{content_url}"


def extract_referenced_attachment_names(page_text: str) -> list[str]:
    """Parse a Confluence html page to generate a list of current
    attachments in use

    Args:
        text (str): The page content

    Returns:
        list[str]: List of filenames currently in use by the page text
    """
    referenced_attachment_filenames = []
    soup = bs4.BeautifulSoup(page_text, "html.parser")
    for attachment in soup.findAll("ri:attachment"):
        referenced_attachment_filenames.append(attachment.attrs["ri:filename"])
    return referenced_attachment_filenames


def datetime_from_string(datetime_string: str) -> datetime:
    datetime_object = datetime.fromisoformat(datetime_string)

    if datetime_object.tzinfo is None:
        # If no timezone info, assume it is UTC
        datetime_object = datetime_object.replace(tzinfo=timezone.utc)
    else:
        # If not in UTC, translate it
        datetime_object = datetime_object.astimezone(timezone.utc)

    return datetime_object


def build_confluence_client(
    credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=wiki_base.rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=credentials_json["confluence_username"] if is_cloud else None,
        password=credentials_json["confluence_access_token"] if is_cloud else None,
        token=credentials_json["confluence_access_token"] if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=60,
        max_backoff_seconds=60,
    )
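The two branches of the removed `datetime_from_string` (naive timestamps assumed UTC, offset-aware timestamps converted to UTC) can be demonstrated directly; the sample timestamps below are illustrative:

from datetime import datetime, timezone

aware = datetime.fromisoformat("2024-05-01T12:00:00+02:00").astimezone(timezone.utc)
naive = datetime.fromisoformat("2024-05-01T12:00:00").replace(tzinfo=timezone.utc)
print(aware.isoformat())  # 2024-05-01T10:00:00+00:00 (converted to UTC)
print(naive.isoformat())  # 2024-05-01T12:00:00+00:00 (assumed to already be UTC)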
@@ -22,18 +22,18 @@ def retry_builder(
    jitter: tuple[float, float] | float = 1,
) -> Callable[[F], F]:
    """Builds a generic wrapper/decorator for calls to external APIs that
    may fail due to rate limiting, flakes, or other reasons. Applies exponential
    may fail due to rate limiting, flakes, or other reasons. Applies expontential
    backoff with jitter to retry the call."""

    @retry(
        tries=tries,
        delay=delay,
        max_delay=max_delay,
        backoff=backoff,
        jitter=jitter,
        logger=cast(Logger, logger),
    )
    def retry_with_default(func: F) -> F:
        @retry(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
            logger=cast(Logger, logger),
        )
        def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
            return func(*args, **kwargs)

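Either shape of `retry_builder` is used the same way from call sites: configure it first, then apply the resulting decorator. A usage sketch (the decorated function is hypothetical; the import path follows the `danswer.utils.retry_wrapper` location seen in the surrounding hunks):

from danswer.utils.retry_wrapper import retry_builder


@retry_builder(tries=3, delay=0.5)
def fetch_remote_resource() -> dict:
    # Any call that may fail transiently; retried with exponential
    # backoff and jitter per the docstring above
    raise NotImplementedError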
@@ -14,6 +14,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -23,7 +24,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.document360.utils import flatten_child_categories
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -21,7 +22,6 @@ from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
# Limitations and Potential Improvements
|
||||
# 1. The "Categories themselves contain potentially relevant information" but they're not pulled in
|
||||
|
||||
@@ -64,7 +64,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -27,8 +28,7 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -175,7 +175,7 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
token = current_tenant_id.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
@@ -199,7 +199,7 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
@@ -39,7 +40,6 @@ from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -3,13 +3,11 @@ from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import SlimDocument
|
||||
|
||||
|
||||
SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
@@ -54,9 +52,9 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlimConnector(BaseConnector):
|
||||
class IdConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ class LoopioConnector(LoadConnector, PollConnector):
|
||||
]
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=str(entry["id"]),
|
||||
id=entry["id"],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.LOOPIO,
|
||||
semantic_identifier=questions[0],
|
||||
|
||||
@@ -14,7 +14,7 @@ class InputType(str, Enum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
SLIM_RETRIEVAL = "slim_retrieval"
|
||||
PRUNE = "prune"
|
||||
|
||||
|
||||
class ConnectorMissingCredentialError(PermissionError):
|
||||
@@ -169,11 +169,6 @@ class Document(DocumentBase):
|
||||
)
|
||||
|
||||
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
perm_sync_data: Any | None = None
|
||||
|
||||
|
||||
class DocumentErrorSummary(BaseModel):
|
||||
id: str
|
||||
semantic_id: str
|
||||
|
||||
@@ -241,29 +241,24 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
# TODO there may be more types to handle here
|
||||
if isinstance(inner_dict, str):
|
||||
# For some objects the innermost value could just be a string, not sure what causes this
|
||||
return inner_dict
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
elif isinstance(inner_dict, dict):
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
if "id" in inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
if "id" in inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
|
||||
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
|
||||
return None
|
||||
@@ -273,13 +268,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if not prop:
|
||||
continue
|
||||
|
||||
try:
|
||||
inner_value = _recurse_properties(prop)
|
||||
except Exception as e:
|
||||
# This is not a critical failure, these properties are not the actual contents of the page
|
||||
# more similar to metadata
|
||||
logger.warning(f"Error recursing properties for {prop_name}: {e}")
|
||||
continue
|
||||
inner_value = _recurse_properties(prop)
|
||||
# Not a perfect way to format Notion database tables but there's no perfect representation
|
||||
# since this must be represented as plaintext
|
||||
if inner_value:
|
||||
|
||||
@@ -11,25 +11,17 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.connectors.salesforce.utils import extract_dict_text
from danswer.utils.logger import setup_logger


# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).


DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000  # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
@@ -37,7 +29,7 @@ ID_PREFIX = "SALESFORCE_"
logger = setup_logger()


class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
@@ -251,22 +243,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        end_datetime = datetime.utcfromtimestamp(end)
        return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)

    def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
    def retrieve_all_source_ids(self) -> set[str]:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        doc_metadata_list: list[SlimDocument] = []
        all_retrieved_ids: set[str] = set()
        for parent_object_type in self.parent_object_list:
            query = f"SELECT Id FROM {parent_object_type}"
            query_result = self.sf_client.query_all(query)
            doc_metadata_list.extend(
                SlimDocument(
                    id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                    perm_sync_data={},
                )
            all_retrieved_ids.update(
                f"{ID_PREFIX}{instance_dict.get('Id', '')}"
                for instance_dict in query_result["records"]
            )

        yield doc_metadata_list
        return all_retrieved_ids


if __name__ == "__main__":

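The IdConnector base that SalesforceConnector now implements is not included in this diff. A minimal sketch of what retrieve_all_source_ids implies the interface looks like (the exact definition in danswer.connectors.interfaces may differ):

import abc


class IdConnector(abc.ABC):
    # Sketch only: a connector that can enumerate every document id currently
    # at the source in one shot, so stale ids can be pruned without
    # re-fetching full documents (unlike SlimConnector, which yields batches
    # of SlimDocument objects carrying permission-sync metadata).
    @abc.abstractmethod
    def retrieve_all_source_ids(self) -> set[str]:
        raise NotImplementedError
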
@@ -13,15 +13,13 @@ from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
@@ -328,7 +326,7 @@ def _get_all_doc_ids(
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> GenerateSlimDocumentOutput:
) -> set[str]:
    """
    Get all document ids in the workspace, channel by channel
    This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
@@ -340,14 +338,13 @@ def _get_all_doc_ids(
        all_channels, channels, channel_name_regex_enabled
    )

    all_doc_ids = set()
    for channel in filtered_channels:
        channel_id = channel["id"]
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
        )

        message_ts_set: set[str] = set()
        for message_batch in channel_message_batches:
            for message in message_batch:
                if msg_filter_func(message):
@@ -356,21 +353,12 @@ def _get_all_doc_ids(
                # The document id is the channel id and the ts of the first message in the thread
                # Since we already have the first message of the thread, we dont have to
                # fetch the thread for id retrieval, saving time and API calls
                message_ts_set.add(message["ts"])
                all_doc_ids.add(f"{channel['id']}__{message['ts']}")

        channel_metadata_list: list[SlimDocument] = []
        for message_ts in message_ts_set:
            channel_metadata_list.append(
                SlimDocument(
                    id=f"{channel_id}__{message_ts}",
                    perm_sync_data={"channel_id": channel_id},
                )
            )

        yield channel_metadata_list
    return all_doc_ids


class SlackPollConnector(PollConnector, SlimConnector):
class SlackPollConnector(PollConnector, IdConnector):
    def __init__(
        self,
        workspace: str,
@@ -391,7 +379,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
        self.client = WebClient(token=bot_token)
        return None

    def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
    def retrieve_all_source_ids(self) -> set[str]:
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")

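As the comment in _get_all_doc_ids notes, a Slack document id is the channel id joined to the thread's first message timestamp. Illustrative values only (the payloads below are hypothetical, shaped like Slack API responses):

channel = {"id": "C0123456789"}
message = {"ts": "1700000000.123456"}

doc_id = f"{channel['id']}__{message['ts']}"
assert doc_id == "C0123456789__1700000000.123456"
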
@@ -10,9 +10,9 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder

logger = setup_logger()

@@ -373,7 +373,7 @@ class WebConnector(LoadConnector):
                    page.close()
            except Exception as e:
                last_error = f"Failed to fetch '{current_url}': {e}"
                logger.exception(last_error)
                logger.error(last_error)
                playwright.stop()
                restart_playwright = True
                continue

@@ -211,7 +211,6 @@ def handle_regular_answer(
        use_citations=use_citations,
        danswerbot_flow=True,
    )

    if not answer.error_msg:
        return answer
    else:

@@ -7,6 +7,7 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse

from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@@ -46,7 +47,6 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
@@ -57,7 +57,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -346,9 +345,7 @@ def process_message(
    respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
    notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
    logger.debug(
        f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
    )
    logger.debug(f"Received Slack request of type: '{req.type}'")

    # Throw out requests that can't or shouldn't be handled
    if not prefilter_requests(req, client):
@@ -360,59 +357,51 @@ def process_message(
        client=client.web_client, channel_id=channel
    )

    # Set the current tenant ID at the beginning for all DB calls within this thread
    if client.tenant_id:
        logger.info(f"Setting tenant ID to {client.tenant_id}")
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
    try:
        with get_session_with_tenant(client.tenant_id) as db_session:
            slack_bot_config = get_slack_bot_config_for_channel(
                channel_name=channel_name, db_session=db_session
            )
    with get_session_with_tenant(client.tenant_id) as db_session:
        slack_bot_config = get_slack_bot_config_for_channel(
            channel_name=channel_name, db_session=db_session
        )

        # Be careful about this default, don't want to accidentally spam every channel
        # Users should be able to DM slack bot in their private channels though
        if (
            slack_bot_config is None
            and not respond_every_channel
            # Can't have configs for DMs so don't toss them out
            and not is_dm
            # If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
            # always respond with the default configs
            and not (details.is_bot_msg or details.bypass_filters)
        ):
            return
    # Be careful about this default, don't want to accidentally spam every channel
    # Users should be able to DM slack bot in their private channels though
    if (
        slack_bot_config is None
        and not respond_every_channel
        # Can't have configs for DMs so don't toss them out
        and not is_dm
        # If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
        # always respond with the default configs
        and not (details.is_bot_msg or details.bypass_filters)
    ):
        return

        follow_up = bool(
            slack_bot_config
            and slack_bot_config.channel_config
            and slack_bot_config.channel_config.get("follow_up_tags") is not None
        )
        feedback_reminder_id = schedule_feedback_reminder(
            details=details, client=client.web_client, include_followup=follow_up
        )
    follow_up = bool(
        slack_bot_config
        and slack_bot_config.channel_config
        and slack_bot_config.channel_config.get("follow_up_tags") is not None
    )
    feedback_reminder_id = schedule_feedback_reminder(
        details=details, client=client.web_client, include_followup=follow_up
    )

        failed = handle_message(
            message_info=details,
            slack_bot_config=slack_bot_config,
            client=client.web_client,
            feedback_reminder_id=feedback_reminder_id,
            tenant_id=client.tenant_id,
        )
    failed = handle_message(
        message_info=details,
        slack_bot_config=slack_bot_config,
        client=client.web_client,
        feedback_reminder_id=feedback_reminder_id,
        tenant_id=client.tenant_id,
    )

        if failed:
            if feedback_reminder_id:
                remove_scheduled_feedback_reminder(
                    client=client.web_client,
                    channel=details.sender,
                    msg_id=feedback_reminder_id,
                )
            # Skipping answering due to pre-filtering is not considered a failure
            if notify_no_answer:
                apologize_for_fail(details, client)
    finally:
        if client.tenant_id:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
    if failed:
        if feedback_reminder_id:
            remove_scheduled_feedback_reminder(
                client=client.web_client,
                channel=details.sender,
                msg_id=feedback_reminder_id,
            )
        # Skipping answering due to pre-filtering is not considered a failure
        if notify_no_answer:
            apologize_for_fail(details, client)


def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
@@ -510,9 +499,7 @@ if __name__ == "__main__":
            for tenant_id in tenant_ids:
                with get_session_with_tenant(tenant_id) as db_session:
                    try:
                        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
                        latest_slack_bot_tokens = fetch_tokens()
                        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

                        if (
                            tenant_id not in slack_bot_tokens
@@ -546,11 +533,6 @@ if __name__ == "__main__":
                            socket_client = _get_socket_client(
                                latest_slack_bot_tokens, tenant_id
                            )

                            # Initialize socket client for this tenant. Each tenant has its own
                            # socket client, allowing for multiple concurrent connections (one
                            # per tenant) with the tenant ID wrapped in the socket model client.
                            # Each `connect` stores websocket connection in a separate thread.
                            _initialize_socket_client(socket_client)

                            socket_clients[tenant_id] = socket_client

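The block removed from process_message follows the standard ContextVar token pattern: set the tenant for the current thread of execution, do the work, then reset in a finally. A self-contained sketch of that pattern (the default value here is an assumption; the real variable lives in shared_configs.configs):

from contextvars import ContextVar

CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str] = ContextVar(
    "current_tenant_id", default="public"  # assumed default
)


def run_for_tenant(tenant_id: str) -> None:
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        pass  # all DB/Redis calls on this thread now observe tenant_id
    finally:
        # Always restore the previous value, even if the work raises
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


run_for_tenant("tenant_abc")
assert CURRENT_TENANT_ID_CONTEXTVAR.get() == "public"
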
@@ -341,8 +341,6 @@ def add_credential_to_connector(
    access_type: AccessType,
    groups: list[int] | None,
    auto_sync_options: dict | None = None,
    initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
    last_successful_index_time: datetime | None = None,
) -> StatusResponse:
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id(credential_id, user, db_session)
@@ -386,10 +384,9 @@ def add_credential_to_connector(
        connector_id=connector_id,
        credential_id=credential_id,
        name=cc_pair_name,
        status=initial_status,
        status=ConnectorCredentialPairStatus.ACTIVE,
        access_type=access_type,
        auto_sync_options=auto_sync_options,
        last_successful_index_time=last_successful_index_time,
    )
    db_session.add(association)
    db_session.flush()  # make sure the association has an id

@@ -40,8 +40,6 @@ CREDENTIAL_PERMISSIONS_TO_IGNORE = {
    DocumentSource.MEDIAWIKI,
}

PUBLIC_CREDENTIAL_ID = 0


def _add_user_filters(
    stmt: Select,
@@ -243,7 +241,7 @@ def create_credential(
        curator_public=credential_data.curator_public,
    )
    db_session.add(credential)
    db_session.flush()  # This ensures the credential gets an ID
    db_session.flush()  # This ensures the credential gets an IDcredentials
    _relate_credential_to_user_groups__no_commit(
        db_session=db_session,
        credential_id=credential.id,
@@ -386,11 +384,12 @@ def delete_credential(


def create_initial_public_credential(db_session: Session) -> None:
    public_cred_id = 0
    error_msg = (
        "DB is not in a valid initial state."
        "There must exist an empty public credential for data connectors that do not require additional Auth."
    )
    first_credential = fetch_credential_by_id(PUBLIC_CREDENTIAL_ID, None, db_session)
    first_credential = fetch_credential_by_id(public_cred_id, None, db_session)

    if first_credential is not None:
        if first_credential.credential_json != {} or first_credential.user is not None:
@@ -398,7 +397,7 @@ def create_initial_public_credential(db_session: Session) -> None:
            return

    credential = Credential(
        id=PUBLIC_CREDENTIAL_ID,
        id=public_cred_id,
        credential_json={},
        user_id=None,
    )
@@ -406,24 +405,6 @@ def create_initial_public_credential(db_session: Session) -> None:
    db_session.commit()


def cleanup_gmail_credentials(db_session: Session) -> None:
    gmail_credentials = fetch_credentials_by_source(
        db_session=db_session, user=None, document_source=DocumentSource.GMAIL
    )
    for credential in gmail_credentials:
        db_session.delete(credential)
    db_session.commit()


def cleanup_google_drive_credentials(db_session: Session) -> None:
    google_drive_credentials = fetch_credentials_by_source(
        db_session=db_session, user=None, document_source=DocumentSource.GOOGLE_DRIVE
    )
    for credential in google_drive_credentials:
        db_session.delete(credential)
    db_session.commit()


def delete_gmail_service_account_credentials(
    user: User | None, db_session: Session
) -> None:

@@ -375,20 +375,6 @@ def update_docs_last_modified__no_commit(
        doc.last_modified = now


def mark_document_as_modified(
    document_id: str,
    db_session: Session,
) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    doc = db_session.scalar(stmt)
    if doc is None:
        raise ValueError(f"No document with ID: {document_id}")

    # update last_synced
    doc.last_modified = datetime.now(timezone.utc)
    db_session.commit()


def mark_document_as_synced(document_id: str, db_session: Session) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    doc = db_session.scalar(stmt)

@@ -36,11 +36,10 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import current_tenant_id

logger = setup_logger()

@@ -189,29 +188,6 @@ class SqlEngine:
        return cls._app_name


def get_all_tenant_ids() -> list[str] | list[None]:
    if not MULTI_TENANT:
        return [None]
    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
        result = session.execute(
            text(
                f"""
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
            )
        )
        tenant_ids = [row[0] for row in result]

    valid_tenants = [
        tenant
        for tenant in tenant_ids
        if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
    ]

    return valid_tenants


def build_connection_string(
    *,
    db_api: str = ASYNC_DB_API,
@@ -260,12 +236,12 @@ def get_current_tenant_id(request: Request) -> str:
    """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
    if not MULTI_TENANT:
        tenant_id = POSTGRES_DEFAULT_SCHEMA
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        current_tenant_id.set(tenant_id)
        return tenant_id

    token = request.cookies.get("tenant_details")
    if not token:
        current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
        current_value = current_tenant_id.get()
        # If no token is present, use the default schema or handle accordingly
        return current_value

@@ -273,14 +249,14 @@ def get_current_tenant_id(request: Request) -> str:
        payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
        tenant_id = payload.get("tenant_id")
        if not tenant_id:
            return CURRENT_TENANT_ID_CONTEXTVAR.get()
            return current_tenant_id.get()
        if not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID format")
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        current_tenant_id.set(tenant_id)

        return tenant_id
    except jwt.InvalidTokenError:
        return CURRENT_TENANT_ID_CONTEXTVAR.get()
        return current_tenant_id.get()
    except Exception as e:
        logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@@ -291,7 +267,7 @@ async def get_async_session_with_tenant(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    if tenant_id is None:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        tenant_id = current_tenant_id.get()

    if not is_valid_schema_name(tenant_id):
        logger.error(f"Invalid tenant ID: {tenant_id}")
@@ -323,9 +299,9 @@ def get_session_with_tenant(
    engine = get_sqlalchemy_engine()

    if tenant_id is None:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        tenant_id = current_tenant_id.get()
    else:
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        current_tenant_id.set(tenant_id)

    event.listen(engine, "checkout", set_search_path_on_checkout)

@@ -361,22 +337,26 @@ def get_session_with_tenant(
def set_search_path_on_checkout(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    tenant_id = current_tenant_id.get()
    if tenant_id and is_valid_schema_name(tenant_id):
        with dbapi_conn.cursor() as cursor:
            cursor.execute(f'SET search_path TO "{tenant_id}"')
            logger.debug(
                f"Set search_path to {tenant_id} for connection {connection_record}"
            )


def get_session_generator_with_tenant() -> Generator[Session, None, None]:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
def get_session_generator_with_tenant(
    tenant_id: str | None = None,
) -> Generator[Session, None, None]:
    with get_session_with_tenant(tenant_id) as session:
        yield session


def get_session() -> Generator[Session, None, None]:
    """Generate a database session with the appropriate tenant schema set."""
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
    tenant_id = current_tenant_id.get()
    if tenant_id == "public" and MULTI_TENANT:
        raise HTTPException(status_code=401, detail="User must authenticate")

    engine = get_sqlalchemy_engine()
@@ -392,7 +372,7 @@ def get_session() -> Generator[Session, None, None]:

async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Generate an async database session with the appropriate tenant schema set."""
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    tenant_id = current_tenant_id.get()
    engine = get_sqlalchemy_async_engine()
    async with AsyncSession(engine, expire_on_commit=False) as async_session:
        if MULTI_TENANT:

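The engine changes rename the context variable to current_tenant_id and let callers either pass a tenant explicitly or fall back to the ambient value. A toy stand-in showing that resolution order (this is not the real get_session_with_tenant, which yields a SQLAlchemy Session bound to the tenant's schema):

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")


@contextmanager
def session_scope(tenant_id: str | None = None) -> Iterator[str]:
    # Explicit argument wins and is published to the context var;
    # otherwise the ambient tenant is used.
    if tenant_id is None:
        tenant_id = current_tenant_id.get()
    else:
        current_tenant_id.set(tenant_id)
    yield tenant_id


with session_scope("tenant_abc") as t:
    assert t == "tenant_abc"
with session_scope() as t:
    assert t == "tenant_abc"  # picked up from the context var set above
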
@@ -1,6 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from sqlalchemy import and_
@@ -67,32 +66,6 @@ def create_index_attempt(
    return new_attempt.id


def mock_successful_index_attempt(
    connector_credential_pair_id: int,
    search_settings_id: int,
    docs_indexed: int,
    db_session: Session,
) -> int:
    """Should not be used in any user triggered flows"""
    db_time = func.now()
    new_attempt = IndexAttempt(
        connector_credential_pair_id=connector_credential_pair_id,
        search_settings_id=search_settings_id,
        from_beginning=True,
        status=IndexingStatus.SUCCESS,
        total_docs_indexed=docs_indexed,
        new_docs_indexed=docs_indexed,
        # Need this to be some convincing random looking value and it can't be 0
        # or the indexing rate would calculate out to infinity
        time_started=db_time - timedelta(seconds=1.92),
        time_updated=db_time,
    )
    db_session.add(new_attempt)
    db_session.commit()

    return new_attempt.id


def get_in_progress_index_attempts(
    connector_id: int | None,
    db_session: Session,

@@ -95,11 +95,10 @@ def upsert_llm_provider(
        group_ids=llm_provider.groups,
        db_session=db_session,
    )
    full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)

    db_session.commit()

    return full_llm_provider
    return FullLLMProvider.from_model(existing_llm_provider)


def fetch_existing_embedding_providers(

@@ -56,7 +56,7 @@ def get_notification_by_id(
    if not notif:
        raise ValueError(f"No notification found with id {notification_id}")
    if notif.user_id != user_id and not (
        notif.user_id is None and user is not None and user.role == UserRole.ADMIN
        notif.user_id is None and user.role == UserRole.ADMIN
    ):
        raise PermissionError(
            f"User {user_id} is not authorized to access notification {notification_id}"

@@ -328,6 +328,7 @@ def update_all_personas_display_priority(

    for persona in personas:
        persona.display_priority = display_priority_map[persona.id]

    db_session.commit()

@@ -8,6 +8,7 @@ from typing import Any
from typing import cast

import httpx
import requests
from retry import retry

from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
@@ -193,21 +194,20 @@ def _get_chunks_via_visit_api(

    document_chunks: list[dict] = []
    while True:
        response = requests.get(url, params=params)
        try:
            filtered_params = {k: v for k, v in params.items() if v is not None}
            with get_vespa_http_client() as http_client:
                response = http_client.get(url, params=filtered_params)
                response.raise_for_status()
        except httpx.HTTPError as e:
            error_base = "Failed to query Vespa"
            response.raise_for_status()
        except requests.HTTPError as e:
            request_info = f"Headers: {response.request.headers}\nPayload: {params}"
            response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
            error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}"
            logger.error(
                f"{error_base}:\n"
                f"Request URL: {e.request.url}\n"
                f"Request Headers: {e.request.headers}\n"
                f"Request Payload: {params}\n"
                f"Exception: {str(e)}"
                f"{request_info}\n"
                f"{response_info}\n"
                f"Exception: {e}"
            )
            raise httpx.HTTPError(error_base) from e
            raise requests.HTTPError(error_base) from e

        # Check if the response contains any documents
        response_data = response.json()
@@ -301,13 +301,17 @@ def query_vespa(
        response = http_client.post(SEARCH_ENDPOINT, json=params)
        response.raise_for_status()
    except httpx.HTTPError as e:
        request_info = f"Headers: {response.request.headers}\nPayload: {params}"
        response_info = (
            f"Status Code: {response.status_code}\n"
            f"Response Content: {response.text}"
        )
        error_base = "Failed to query Vespa"
        logger.error(
            f"{error_base}:\n"
            f"Request URL: {e.request.url}\n"
            f"Request Headers: {e.request.headers}\n"
            f"Request Payload: {params}\n"
            f"Exception: {str(e)}"
            f"{request_info}\n"
            f"{response_info}\n"
            f"Exception: {e}"
        )
        raise httpx.HTTPError(error_base) from e


@@ -118,7 +118,7 @@ def get_existing_documents_from_chunks(
    return document_ids


@retry(tries=5, delay=1, backoff=2)
@retry(tries=3, delay=1, backoff=2)
def _index_vespa_chunk(
    chunk: DocMetadataAwareIndexChunk,
    index_name: str,

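The retry change on _index_vespa_chunk roughly quarters the worst-case backoff: with delay=1 and backoff=2, tries=3 sleeps at most 1s + 2s = 3s between attempts, where tries=5 could sleep 1 + 2 + 4 + 8 = 15s. A quick check of that arithmetic, assuming the retry package's usual delay-times-backoff-per-attempt behavior:

def total_retry_sleep(tries: int, delay: float, backoff: float) -> float:
    # Sum of sleeps between attempts: delay, delay*backoff, delay*backoff**2, ...
    return sum(delay * backoff**i for i in range(tries - 1))


assert total_retry_sleep(3, 1, 2) == 3.0   # new setting
assert total_retry_sleep(5, 1, 2) == 15.0  # old setting
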
@@ -103,9 +103,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        self,
        chunks: list[DocAwareChunk],
    ) -> list[IndexChunk]:
        """Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
        if they exist. If there is no space for it, it would have been thrown out at the chunking step.
        """
        # All chunks at this point must have some non-empty content
        flat_chunk_texts: list[str] = []
        large_chunks_present = False
@@ -124,11 +121,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            flat_chunk_texts.append(chunk_text)

            if chunk.mini_chunk_texts:
                if chunk.large_chunk_reference_ids:
                    # A large chunk does not contain mini chunks, if it matches the large chunk
                    # with a high score, then mini chunks would not be used anyway
                    # otherwise it should match the normal chunk
                    raise RuntimeError("Large chunk contains mini chunks")
                flat_chunk_texts.extend(chunk.mini_chunk_texts)

        embeddings = self.embedding_model.encode(

@@ -195,8 +195,6 @@ def index_doc_batch_prepare(
    db_session: Session,
    ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
    """This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
    This preceeds indexing it into the actual document index."""
    documents = []
    for document in document_batch:
        empty_contents = not any(section.text.strip() for section in document.sections)

@@ -3,6 +3,5 @@ from danswer.key_value_store.store import PgRedisKVStore


def get_kv_store() -> KeyValueStore:
    # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
    # It's read from the global thread level variable
    # this is the only one supported currently
    return PgRedisKVStore()

@@ -14,8 +14,6 @@ class KvKeyNotFoundError(Exception):


class KeyValueStore:
    # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
    # It's read from the global thread level variable
    @abc.abstractmethod
    def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
        raise NotImplementedError

@@ -4,7 +4,6 @@ from contextlib import contextmanager
from typing import cast

from fastapi import HTTPException
from redis.client import Redis
from sqlalchemy import text
from sqlalchemy.orm import Session

@@ -17,8 +16,8 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import current_tenant_id


logger = setup_logger()

@@ -28,23 +27,17 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24  # 1 Day


class PgRedisKVStore(KeyValueStore):
    def __init__(
        self, redis_client: Redis | None = None, tenant_id: str | None = None
    ) -> None:
        # If no redis_client is provided, fall back to the context var
        if redis_client is not None:
            self.redis_client = redis_client
        else:
            tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
            self.redis_client = get_redis_client(tenant_id=tenant_id)
    def __init__(self) -> None:
        tenant_id = current_tenant_id.get()
        self.redis_client = get_redis_client(tenant_id=tenant_id)

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        engine = get_sqlalchemy_engine()
        with Session(engine, expire_on_commit=False) as session:
            if MULTI_TENANT:
                tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
                if tenant_id == POSTGRES_DEFAULT_SCHEMA:
                tenant_id = current_tenant_id.get()
                if tenant_id == "public":
                    raise HTTPException(
                        status_code=401, detail="User must authenticate"
                    )

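After this change PgRedisKVStore takes no constructor arguments and resolves everything from ambient context. A toy in-memory stand-in, just to show the calling convention of the KeyValueStore interface (only store is shown in the hunks; other methods are omitted here):

import abc
from typing import Any


class KeyValueStore(abc.ABC):
    @abc.abstractmethod
    def store(self, key: str, val: Any, encrypt: bool = False) -> None:
        raise NotImplementedError


class InMemoryKVStore(KeyValueStore):
    # Stand-in for PgRedisKVStore; the real class writes through to
    # Postgres and Redis for the current tenant.
    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def store(self, key: str, val: Any, encrypt: bool = False) -> None:
        self._data[key] = val


kv = InMemoryKVStore()
kv.store("slack_bot_tokens", {"bot_token": "xoxb-..."})
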
@@ -79,6 +79,8 @@ def _get_answer_stream_processor(
    doc_id_to_rank_map: DocumentIdOrderMapping,
    answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
    print("ANSWERR STYES")
    print(answer_style_configs.__dict__)
    if answer_style_configs.citation_config:
        return build_citation_processor(
            context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
@@ -237,7 +239,6 @@ class Answer:
            prompt=prompt,
            tools=final_tool_definitions if final_tool_definitions else None,
            tool_choice="required" if self.force_use_tool.force_use else None,
            structured_response_format=self.answer_style_config.structured_response_format,
        ):
            if isinstance(message, AIMessageChunk) and (
                message.tool_call_chunks or message.tool_calls
@@ -332,10 +333,7 @@ class Answer:
        tool_choice: ToolChoiceOptions | None = None,
    ) -> Iterator[str | StreamStopInfo]:
        for message in self.llm.stream(
            prompt=prompt,
            tools=tools,
            tool_choice=tool_choice,
            structured_response_format=self.answer_style_config.structured_response_format,
            prompt=prompt, tools=tools, tool_choice=tool_choice
        ):
            if isinstance(message, AIMessageChunk):
                if message.content:

@@ -116,10 +116,6 @@ class AnswerStyleConfig(BaseModel):
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_quotes_and_citation(self) -> "AnswerStyleConfig":

@@ -226,6 +226,7 @@ def process_model_tokens(
    hold_quote = ""

    for token in tokens:
        print(f"Token: {token}")
        model_previous = model_output
        model_output += token

@@ -280,7 +280,6 @@ class DefaultMultiLLM(LLM):
        tools: list[dict] | None,
        tool_choice: ToolChoiceOptions | None,
        stream: bool,
        structured_response_format: dict | None = None,
    ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
        if isinstance(prompt, list):
            prompt = [
@@ -314,11 +313,6 @@ class DefaultMultiLLM(LLM):
                # NOTE: we can't pass this in if tools are not specified
                # or else OpenAI throws an error
                **({"parallel_tool_calls": False} if tools else {}),
                **(
                    {"response_format": structured_response_format}
                    if structured_response_format
                    else {}
                ),
                **self._model_kwargs,
            )
        except Exception as e:
@@ -342,16 +336,12 @@ class DefaultMultiLLM(LLM):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        response = cast(
            litellm.ModelResponse,
            self._completion(
                prompt, tools, tool_choice, False, structured_response_format
            ),
            litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
        )
        choice = response.choices[0]
        if hasattr(choice, "message"):
@@ -364,21 +354,18 @@ class DefaultMultiLLM(LLM):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
            yield self.invoke(prompt)
            return

        output = None
        response = cast(
            litellm.CustomStreamWrapper,
            self._completion(
                prompt, tools, tool_choice, True, structured_response_format
            ),
            self._completion(prompt, tools, tool_choice, True),
        )
        try:
            for part in response:

@@ -80,7 +80,6 @@ class CustomModelServer(LLM):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        return self._execute(prompt)

@@ -89,6 +88,5 @@ class CustomModelServer(LLM):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._execute(prompt)

@@ -51,7 +51,6 @@ def get_llms_for_persona(
    return get_llm(
        provider=llm_provider.provider,
        model=model,
        deployment_name=llm_provider.deployment_name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
@@ -105,7 +104,7 @@ def get_default_llms(
def get_llm(
    provider: str,
    model: str,
    deployment_name: str | None,
    deployment_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    api_version: str | None = None,
@@ -117,7 +116,6 @@ def get_llm(
    return DefaultMultiLLM(
        model_provider=provider,
        model_name=model,
        deployment_name=deployment_name,
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,

@@ -88,14 +88,11 @@ class LLM(abc.ABC):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format
        )
        return self._invoke_implementation(prompt, tools, tool_choice)

    @abc.abstractmethod
    def _invoke_implementation(
@@ -103,7 +100,6 @@ class LLM(abc.ABC):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

@@ -112,14 +108,11 @@ class LLM(abc.ABC):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._stream_implementation(
            prompt, tools, tool_choice, structured_response_format
        )
        return self._stream_implementation(prompt, tools, tool_choice)

    @abc.abstractmethod
    def _stream_implementation(
@@ -127,6 +120,5 @@ class LLM(abc.ABC):
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError

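All of the structured_response_format plumbing above ends at litellm's response_format keyword. The dict a caller supplies is passed through untouched; something OpenAI-style like this is the expected shape (a hypothetical schema, not one from this diff, and exact schema support depends on the provider):

structured_response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "answer",
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    },
}
# Flow per the hunks above:
#   LLM.invoke(..., structured_response_format=structured_response_format)
#     -> DefaultMultiLLM._completion(...)
#       -> litellm.completion(..., response_format=structured_response_format)
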
@@ -61,7 +61,6 @@ BEDROCK_MODEL_NAMES = [
IGNORABLE_ANTHROPIC_MODELS = [
    "claude-2",
    "claude-instant-1",
    "anthropic/claude-3-5-sonnet-20241022",
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [
@@ -101,8 +100,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
        api_version_required=False,
        custom_config_keys=[],
        llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
        default_model="claude-3-5-sonnet-20241022",
        default_fast_model="claude-3-5-sonnet-20241022",
        default_model="claude-3-5-sonnet-20240620",
        default_fast_model="claude-3-5-sonnet-20240620",
    ),
    WellKnownLLMProviderDescriptor(
        name=AZURE_PROVIDER_NAME,
@@ -136,8 +135,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
        ),
    ],
    llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
    default_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    default_fast_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    default_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    default_fast_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    ),
]

@@ -342,26 +342,12 @@ def get_llm_max_tokens(

    try:
        model_obj = model_map.get(f"{model_provider}/{model_name}")
        if model_obj:
        if not model_obj:
            model_obj = model_map[model_name]
            logger.debug(f"Using model object for {model_name}")
        else:
            logger.debug(f"Using model object for {model_provider}/{model_name}")

        if not model_obj:
            model_obj = model_map.get(model_name)
            if model_obj:
                logger.debug(f"Using model object for {model_name}")

        if not model_obj:
            model_name_split = model_name.split("/")
            if len(model_name_split) > 1:
                model_obj = model_map.get(model_name_split[1])
                if model_obj:
                    logger.debug(f"Using model object for {model_name_split[1]}")

        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
            )

        if "max_input_tokens" in model_obj:
            max_tokens = model_obj["max_input_tokens"]
            logger.info(

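The longer variant of the lookup in this hunk tries progressively looser keys before giving up. A self-contained mirror of that fallback chain (function and variable names here are illustrative):

def find_model_obj(model_map: dict, provider: str, model_name: str) -> dict:
    # 1) "provider/model", 2) bare model name, 3) suffix of a slash-qualified name
    model_obj = model_map.get(f"{provider}/{model_name}")
    if not model_obj:
        model_obj = model_map.get(model_name)
    if not model_obj and "/" in model_name:
        model_obj = model_map.get(model_name.split("/")[1])
    if not model_obj:
        raise RuntimeError(f"No litellm entry found for {provider}/{model_name}")
    return model_obj


model_map = {"gpt-4": {"max_input_tokens": 8192}}
assert find_model_obj(model_map, "openai", "gpt-4")["max_input_tokens"] == 8192
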
@@ -96,7 +96,6 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import SENTRY_DSN


logger = setup_logger()


@@ -214,7 +213,7 @@ def get_application() -> FastAPI:
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            integrations=[StarletteIntegration(), FastApiIntegration()],
            traces_sample_rate=0.1,
            traces_sample_rate=0.5,
        )
        logger.info("Sentry initialized")
    else:

@@ -61,30 +61,6 @@ class TenantRedis(redis.Redis):

        return wrapper

    def _prefix_scan_iter(self, method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Prefix the match pattern if provided
            if "match" in kwargs:
                kwargs["match"] = self._prefixed(kwargs["match"])
            elif len(args) > 0:
                args = (self._prefixed(args[0]),) + args[1:]

            # Get the iterator
            iterator = method(*args, **kwargs)

            # Remove prefix from returned keys
            prefix = f"{self.tenant_id}:".encode()
            prefix_len = len(prefix)

            for key in iterator:
                if isinstance(key, bytes) and key.startswith(prefix):
                    yield key[prefix_len:]
                else:
                    yield key

        return wrapper

    def __getattribute__(self, item: str) -> Any:
        original_attr = super().__getattribute__(item)
        methods_to_wrap = [
@@ -98,15 +74,13 @@ class TenantRedis(redis.Redis):
            "hset",
            "hget",
            "getset",
            "scan_iter",
            "owned",
            "reacquire",
            "create_lock",
            "startswith",
        ]  # Regular methods that need simple prefixing

        if item == "scan_iter":
            return self._prefix_scan_iter(original_attr)
        elif item in methods_to_wrap and callable(original_attr):
        ]  # Add all methods that need prefixing
        if item in methods_to_wrap and callable(original_attr):
            return self._prefix_method(original_attr)
        return original_attr

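TenantRedis isolates tenants by rewriting keys as "<tenant_id>:<key>" on the way in; the removed _prefix_scan_iter additionally stripped that prefix from keys coming back out of scan_iter. A toy illustration of both directions (names here are illustrative, not the actual TenantRedis internals):

tenant_id = "tenant_abc"
prefix = f"{tenant_id}:".encode()


def prefixed(key: str) -> str:
    return f"{tenant_id}:{key}"


assert prefixed("slack_bot_tokens") == "tenant_abc:slack_bot_tokens"

# Un-prefix scan results so callers never see the tenant namespacing
raw_scan_result = [b"tenant_abc:doc_1", b"tenant_abc:doc_2"]
visible = [k[len(prefix):] for k in raw_scan_result if k.startswith(prefix)]
assert visible == [b"doc_1", b"doc_2"]
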
@@ -10,7 +10,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import current_tenant_id
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
@@ -162,7 +162,7 @@ def retrieval_preprocessing(
        time_cutoff=time_filter or predicted_time_cutoff,
        tags=preset_filters.tags,  # Tags are never auto-extracted
        access_control_list=user_acl_filters,
        tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
        tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
    )

    llm_evaluation_type = LLMEvaluationType.BASIC

File diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.