mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
23 Commits
remove_emp
...
gating
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d66bdd061 | ||
|
|
8f67f1715c | ||
|
|
3b365509e2 | ||
|
|
022cbdfccf | ||
|
|
ebec6f6b10 | ||
|
|
1cad9c7b3d | ||
|
|
b4e975013c | ||
|
|
dd26f92206 | ||
|
|
4d00ec45ad | ||
|
|
1a81c67a67 | ||
|
|
04f965e656 | ||
|
|
277d37e0ee | ||
|
|
3cd260131b | ||
|
|
ad21ee0e9a | ||
|
|
c7dc0e9af0 | ||
|
|
75c5de802b | ||
|
|
c39f590d0d | ||
|
|
82a9fda846 | ||
|
|
842d4ab2a8 | ||
|
|
cddcec4ea4 | ||
|
|
09dd7b424c | ||
|
|
a2fd8d5e0a | ||
|
|
802dc00f78 |
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
@@ -6,24 +6,20 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
## Accepted Risk
|
||||
[Any know risks or failure modes to point out to reviewers]
|
||||
|
||||
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
## Related Issue(s)
|
||||
[If applicable, link to the issue(s) this PR addresses]
|
||||
|
||||
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
## Checklist:
|
||||
- [ ] All of the automated tests pass
|
||||
- [ ] All PR comments are addressed and marked resolved
|
||||
- [ ] If there are migrations, they have been rebased to latest main
|
||||
- [ ] If there are new dependencies, they are added to the requirements
|
||||
- [ ] If there are new environment variables, they are added to all of the deployment methods
|
||||
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- [ ] Docker images build and basic functionalities work
|
||||
- [ ] Author has done a final read through of the PR right before merge
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
name: Build and Push Cloud Web Image on Tag
|
||||
# Identical to the web container build, but with correct image tag and build args
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=${{ matrix.platform == 'linux/amd64' && '8cpu-linux-x64' || '8cpu-linux-arm64' }}
|
||||
- run-id=${{ github.run_id }}
|
||||
- tag=platform-${{ matrix.platform }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: true
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
NEXT_PUBLIC_CLOUD_ENABLED=true
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: digests-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- build
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
23
.github/workflows/nightly-close-stale-issues.yml
vendored
23
.github/workflows/nightly-close-stale-issues.yml
vendored
@@ -1,23 +0,0 @@
|
||||
name: 'Nightly - Close stale issues and PRs'
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC
|
||||
|
||||
permissions:
|
||||
# contents: write # only for delete-branch option
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
|
||||
close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
|
||||
days-before-stale: 75
|
||||
# days-before-close: 90 # uncomment after we test stale behavior
|
||||
|
||||
65
.github/workflows/pr-Integration-tests.yml
vendored
65
.github/workflows/pr-Integration-tests.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
@@ -85,58 +85,7 @@ jobs:
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
|
||||
- name: Start Docker containers
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
@@ -181,7 +130,7 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Standard Integration Tests
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
@@ -196,8 +145,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests
|
||||
danswer/danswer-integration:test
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
@@ -210,11 +158,6 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
|
||||
124
.github/workflows/pr-backport-autotrigger.yml
vendored
124
.github/workflows/pr-backport-autotrigger.yml
vendored
@@ -1,124 +0,0 @@
|
||||
name: Backport on Merge
|
||||
|
||||
# Note this workflow does not trigger the builds, be sure to manually tag the branches to trigger the builds
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed] # Later we check for merge so only PRs that go in can get backported
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
actions: write
|
||||
|
||||
jobs:
|
||||
backport:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.YUHONG_GH_ACTIONS }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ssh-key: "${{ secrets.RKUO_DEPLOY_KEY }}"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Git user
|
||||
run: |
|
||||
git config user.name "Richard Kuo [bot]"
|
||||
git config user.email "rkuo[bot]@danswer.ai"
|
||||
git fetch --prune
|
||||
|
||||
- name: Check for Backport Checkbox
|
||||
id: checkbox-check
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
|
||||
echo "backport=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "backport=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: List and sort release branches
|
||||
id: list-branches
|
||||
run: |
|
||||
git fetch --all --tags
|
||||
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
|
||||
BETA=$(echo "$BRANCHES" | head -n 1)
|
||||
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
|
||||
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
|
||||
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
|
||||
# Fetch latest tags for beta and stable
|
||||
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
|
||||
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
|
||||
|
||||
# Handle case where no beta tags exist
|
||||
if [[ -z "$LATEST_BETA_TAG" ]]; then
|
||||
NEW_BETA_TAG="v1.0.0-beta.1"
|
||||
else
|
||||
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
|
||||
fi
|
||||
|
||||
# Increment latest stable tag
|
||||
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
|
||||
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
|
||||
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Echo branch and tag information
|
||||
run: |
|
||||
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
|
||||
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
|
||||
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
|
||||
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
|
||||
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
|
||||
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
|
||||
|
||||
- name: Trigger Backport
|
||||
if: steps.checkbox-check.outputs.backport == 'true'
|
||||
run: |
|
||||
set -e
|
||||
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
|
||||
|
||||
# Echo the merge commit SHA
|
||||
echo "Merge commit SHA: ${{ github.event.pull_request.merge_commit_sha }}"
|
||||
|
||||
# Fetch all history for all branches and tags
|
||||
git fetch --prune
|
||||
|
||||
# Reset and prepare the beta branch
|
||||
git checkout ${{ steps.list-branches.outputs.beta }}
|
||||
echo "Last 5 commits on beta branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to beta failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new beta branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
# Push the changes and tag to the beta branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.beta }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
|
||||
|
||||
# Reset and prepare the stable branch
|
||||
git checkout ${{ steps.list-branches.outputs.stable }}
|
||||
echo "Last 5 commits on stable branch:"
|
||||
git log -n 5 --pretty=format:"%H"
|
||||
echo "" # Newline for formatting
|
||||
|
||||
# Cherry-pick the merge commit from the merged PR
|
||||
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
|
||||
echo "Cherry-pick to stable failed due to conflicts."
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Create new stable branch/tag
|
||||
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
# Push the changes and tag to the stable branch using PAT
|
||||
git push origin ${{ steps.list-branches.outputs.stable }}
|
||||
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}
|
||||
300
.vscode/launch.template.jsonc
vendored
300
.vscode/launch.template.jsonc
vendored
@@ -6,69 +6,19 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
@@ -79,11 +29,7 @@
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
@@ -102,11 +48,7 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
@@ -126,13 +68,43 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
]
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
{
|
||||
"name": "Indexing",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"--no-indexing"
|
||||
]
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -146,151 +118,7 @@
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"danswer.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
@@ -309,22 +137,8 @@
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
]
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
@@ -333,27 +147,7 @@
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Prompt sharing with specific teammates and user groups.
|
||||
* Multimodal model support, chat with images, video etc.
|
||||
* Multi-Model model support, chat with images, video etc.
|
||||
* Choosing between LLMs and parameters during chat session.
|
||||
* Tool calling and agent configurations options.
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
@@ -9,12 +9,11 @@ from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.background.celery.celery_app import get_all_tenant_ids
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
@@ -58,15 +57,11 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
schema_name = x_args.get("schema", "public")
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
if MULTI_TENANT and schema_name == "public":
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""Migrate chat_session and chat_message tables to use UUID primary keys
|
||||
|
||||
"""
|
||||
Revision ID: 6756efa39ada
|
||||
Revises: 5d12a446f5c0
|
||||
Create Date: 2024-10-15 17:47:44.108537
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
@@ -14,6 +12,8 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
"""
|
||||
Migrate chat_session and chat_message tables to use UUID primary keys.
|
||||
|
||||
This script:
|
||||
1. Adds UUID columns to chat_session and chat_message
|
||||
2. Populates new columns with UUIDs
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
"""remove rt
|
||||
|
||||
Revision ID: 949b4a92a401
|
||||
Revises: 1b10e1fda030
|
||||
Create Date: 2024-10-26 13:06:06.937969
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Import your models and constants
|
||||
from danswer.db.models import (
|
||||
Connector,
|
||||
ConnectorCredentialPair,
|
||||
Credential,
|
||||
IndexAttempt,
|
||||
)
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "949b4a92a401"
|
||||
down_revision = "1b10e1fda030"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Deletes all RequestTracker connectors and associated data
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
connectors_to_delete = (
|
||||
session.query(Connector)
|
||||
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
|
||||
.all()
|
||||
)
|
||||
|
||||
connector_ids = [connector.id for connector in connectors_to_delete]
|
||||
|
||||
if connector_ids:
|
||||
cc_pairs_to_delete = (
|
||||
session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.connector_id.in_(connector_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs_to_delete]
|
||||
|
||||
if cc_pair_ids:
|
||||
session.query(IndexAttempt).filter(
|
||||
IndexAttempt.connector_credential_pair_id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
session.query(ConnectorCredentialPair).filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
credential_ids = [cc_pair.credential_id for cc_pair in cc_pairs_to_delete]
|
||||
if credential_ids:
|
||||
session.query(Credential).filter(Credential.id.in_(credential_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.query(Connector).filter(Connector.id.in_(connector_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No-op downgrade as we cannot restore deleted data
|
||||
pass
|
||||
@@ -31,12 +31,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# First, update any null values to a default value
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
|
||||
)
|
||||
|
||||
# Then, make the column non-nullable
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
|
||||
|
||||
@@ -70,12 +70,3 @@ class DocumentAccess(ExternalAccess):
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
default_public_access = DocumentAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,6 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
@@ -49,7 +49,6 @@ from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -59,9 +58,10 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
@@ -93,10 +93,7 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -135,9 +132,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
@@ -192,7 +187,7 @@ def verify_email_domain(email: str) -> None:
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
return "public"
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
@@ -240,9 +235,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> User:
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
@@ -253,7 +246,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
@@ -292,9 +285,32 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
current_tenant_id.reset(token)
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
user: User,
|
||||
request: Request | None = None,
|
||||
response: Response | None = None,
|
||||
) -> None:
|
||||
if response is None or not MULTI_TENANT:
|
||||
return
|
||||
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
|
||||
tenant_token = jwt.encode(
|
||||
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key="tenant_details",
|
||||
value=tenant_token,
|
||||
httponly=True,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
oauth_name: str,
|
||||
@@ -311,9 +327,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
@@ -323,11 +337,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
@@ -367,10 +380,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
@@ -417,7 +426,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
@@ -508,22 +517,8 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def write_token(self, user: User) -> str:
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
)
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
return JWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
@@ -1,310 +0,0 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
|
||||
This also does not fire if a worker with acks_late=False crashes (which all of our
|
||||
long running workers are)
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcd = RedisConnectorDeletion(int(cc_pair_id))
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcp = RedisConnectorPruning(int(cc_pair_id))
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
logger.info("Waiting for all tenant primary workers to be ready...")
|
||||
time_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
# Check if we have a primary worker lock for each tenant
|
||||
all_tenants_ready = all(
|
||||
get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
for tenant_id in tenant_ids
|
||||
)
|
||||
|
||||
if all_tenants_ready:
|
||||
break
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
ready_tenants = sum(
|
||||
1
|
||||
for tenant_id in tenant_ids
|
||||
if get_redis_client(tenant_id=tenant_id).exists(
|
||||
DanswerRedisLocks.PRIMARY_WORKER
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Not all tenant primary workers are ready yet. "
|
||||
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Not all tenant primary workers were ready within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
logger.info("All tenant primary workers are ready. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
for tenant_id, lock in sender.primary_worker_locks.items():
|
||||
try:
|
||||
if lock and lock.owned():
|
||||
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
|
||||
try:
|
||||
lock.release()
|
||||
logger.debug(f"Successfully released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
sender.primary_worker_locks[tenant_id] = None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
@@ -1,100 +0,0 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.signals import beat_init
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.beat")
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration once
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
@@ -1,88 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.heavy")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
]
|
||||
)
|
||||
@@ -1,88 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.indexing")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
]
|
||||
)
|
||||
@@ -1,89 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.light")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
@@ -1,300 +0,0 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("danswer.background.celery.configs.primary")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_locks[tenant_id] = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
# @worker_process_init.connect
|
||||
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
||||
# """This only runs inside child processes when the worker is in pool=prefork mode.
|
||||
# This may be technically unnecessary since we're finding prefork pools to be
|
||||
# unstable and currently aren't planning on using them."""
|
||||
# logger.info("worker_process_init signal received.")
|
||||
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
|
||||
|
||||
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
|
||||
# SqlEngine.get_engine().dispose(close=False)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# it's unclear to me whether using the hub's timer or the bootstep timer is better
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
def start(self, worker: Any) -> None:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
# Access the worker's event loop (hub)
|
||||
hub = worker.consumer.controller.hub
|
||||
|
||||
# Schedule the periodic task
|
||||
self.task_tref = hub.call_repeatedly(
|
||||
self.interval, self.run_periodic_task, worker
|
||||
)
|
||||
task_logger.info("Scheduled periodic task with hub.")
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
# Retrieve all tenant IDs
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
lock = worker.primary_worker_locks.get(tenant_id)
|
||||
if not lock:
|
||||
continue # Skip if no lock for this tenant
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug(
|
||||
f"Reacquiring primary worker lock for tenant {tenant_id}."
|
||||
)
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
|
||||
)
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
|
||||
)
|
||||
worker.primary_worker_locks[tenant_id] = lock
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Periodic task failed.")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
if self.task_tref:
|
||||
self.task_tref.cancel()
|
||||
task_logger.info("Canceled periodic task with hub.")
|
||||
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
@@ -1,26 +0,0 @@
|
||||
import logging
|
||||
|
||||
from celery import current_task
|
||||
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class CeleryTaskColoredFormatter(ColoredFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
619
backend/danswer/background/celery/celery_app.py
Normal file
619
backend/danswer/background/celery/celery_app.py
Normal file
@@ -0,0 +1,619 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import current_task
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import beat_init
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.background.celery.celery_utils import get_all_tenant_ids
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
integrations=[CeleryIntegration()],
|
||||
traces_sample_rate=0.5,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object(
|
||||
"danswer.background.celery.celeryconfig"
|
||||
) # Load configuration from 'celeryconfig.py'
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
tenant_id: str | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
|
||||
This also does not fire if a worker with acks_late=False crashes (which all of our
|
||||
long running workers are)
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(int(document_set_id))
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(int(usergroup_id))
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcd = RedisConnectorDeletion(int(cc_pair_id))
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcp = RedisConnectorPruning(int(cc_pair_id))
|
||||
r.srem(rcp.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
# decide some initial startup settings based on the celery worker's hostname
|
||||
# (set at the command line)'
|
||||
|
||||
hostname = sender.hostname
|
||||
if hostname.startswith("light"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
elif hostname.startswith("heavy"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
elif hostname.startswith("indexing"):
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
# TODO: why is this necessary for the indexer to do?
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice(
|
||||
"Running a first inference to warm up embedding model"
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
logger.notice("First inference complete.")
|
||||
else:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
time_start = time.monotonic()
|
||||
logger.notice("Redis: Readiness check starting.")
|
||||
while True:
|
||||
# Log all the locks in Redis
|
||||
all_locks = r.keys("*")
|
||||
logger.notice(f"Current Redis locks: {all_locks}")
|
||||
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
|
||||
break
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
logger.info(
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
logger.info("Wait for primary worker completed successfully. Continuing...")
|
||||
return # Exit the function for secondary workers
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
time_start = time.monotonic()
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# For the moment, we're assuming that we are the only primary worker
|
||||
# that should be running.
|
||||
# TODO: maybe check for or clean up another zombie primary worker if we detect it
|
||||
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
|
||||
|
||||
# this process wide lock is taken to help other workers start up in order.
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
|
||||
if acquired:
|
||||
logger.info("Primary worker lock: Acquire succeeded.")
|
||||
else:
|
||||
logger.error("Primary worker lock: Acquire failed!")
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
sender.primary_worker_locks[tenant_id] = lock
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
# @worker_process_init.connect
|
||||
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
||||
# """This only runs inside child processes when the worker is in pool=prefork mode.
|
||||
# This may be technically unnecessary since we're finding prefork pools to be
|
||||
# unstable and currently aren't planning on using them."""
|
||||
# logger.info("worker_process_init signal received.")
|
||||
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
|
||||
|
||||
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
|
||||
# SqlEngine.get_engine().dispose(close=False)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
task_logger.info("worker_ready signal received.")
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
for tenant_id, lock in sender.primary_worker_locks.items():
|
||||
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
sender.primary_worker_locks = {}
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class CeleryTaskColoredFormatter(ColoredFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
task = current_task
|
||||
if task and task.request:
|
||||
record.__dict__.update(task_id=task.request.id, task_name=task.name)
|
||||
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
root_logger.addHandler(root_file_handler)
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
task_logger.addHandler(task_file_handler)
|
||||
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
|
||||
Use the task_logger in this class to avoid double logging.
|
||||
|
||||
This cannot be done inside a regular beat task because it must run on schedule and
|
||||
a queue of existing work would starve the task from running.
|
||||
"""
|
||||
|
||||
# Requires the Hub component
|
||||
requires = {"celery.worker.components:Hub"}
|
||||
|
||||
def __init__(self, worker: Any, **kwargs: Any) -> None:
|
||||
super().__init__(worker, **kwargs)
|
||||
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
|
||||
self.task_tref = None
|
||||
|
||||
def start(self, worker: Any) -> None:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
# Access the worker's event loop (hub)
|
||||
hub = worker.consumer.controller.hub
|
||||
|
||||
# Schedule the periodic task
|
||||
self.task_tref = hub.call_repeatedly(
|
||||
self.interval, self.run_periodic_task, worker
|
||||
)
|
||||
task_logger.info("Scheduled periodic task with hub.")
|
||||
|
||||
def run_periodic_task(self, worker: Any) -> None:
|
||||
try:
|
||||
if not celery_is_worker_primary(worker):
|
||||
return
|
||||
|
||||
if not hasattr(worker, "primary_worker_locks"):
|
||||
return
|
||||
|
||||
# Retrieve all tenant IDs
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
lock = worker.primary_worker_locks.get(tenant_id)
|
||||
if not lock:
|
||||
continue # Skip if no lock for this tenant
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug(
|
||||
f"Reacquiring primary worker lock for tenant {tenant_id}."
|
||||
)
|
||||
lock.reacquire()
|
||||
else:
|
||||
task_logger.warning(
|
||||
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
|
||||
"Reasons could be worker restart or lock expiration."
|
||||
)
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
|
||||
)
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
|
||||
)
|
||||
if acquired:
|
||||
task_logger.info(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
|
||||
)
|
||||
worker.primary_worker_locks[tenant_id] = lock
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.error(f"Error in periodic task: {e}")
|
||||
|
||||
def stop(self, worker: Any) -> None:
|
||||
# Cancel the scheduled task when the worker stops
|
||||
if self.task_tref:
|
||||
self.task_tref.cancel()
|
||||
task_logger.info("Canceled periodic task with hub.")
|
||||
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
]
|
||||
)
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"options": task["options"],
|
||||
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration once
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
@@ -10,7 +10,7 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
@@ -313,8 +313,6 @@ class RedisConnectorDeletion(RedisObjectHelper):
|
||||
lock: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
@@ -467,8 +465,14 @@ class RedisConnectorPruning(RedisObjectHelper):
|
||||
|
||||
return len(async_results)
|
||||
|
||||
def is_pruning(self, redis_client: Redis) -> bool:
|
||||
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=int(self._id), db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair_id {self._id} does not exist.")
|
||||
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
@@ -534,36 +538,6 @@ class RedisConnectorIndexing(RedisObjectHelper):
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
def is_indexing(self, redis_client: Redis) -> bool:
|
||||
"""A single example of a helper method being refactored into the redis helper"""
|
||||
if redis_client.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class RedisConnectorStop(RedisObjectHelper):
|
||||
"""Used to signal any running tasks for a connector to stop. We should refactor
|
||||
connector related redis helpers into a single class.
|
||||
"""
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
|
||||
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
super().__init__(str(id))
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock | None,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
"""Entry point for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
celery_app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.celery_app", "celery_app"
|
||||
)
|
||||
@@ -1,21 +1,25 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
@@ -71,15 +75,13 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: list[Document],
|
||||
) -> set[str]:
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
@@ -89,13 +91,10 @@ def extract_ids_from_runnable_connector(
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
@@ -104,17 +103,16 @@ def extract_ids_from_runnable_connector(
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Stop signal received")
|
||||
callback.progress(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
@@ -135,10 +133,33 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
||||
def celery_is_worker_primary(worker: Any) -> bool:
|
||||
"""There are multiple approaches that could be taken to determine if a celery worker
|
||||
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
||||
for the celery worker, which can be done on the
|
||||
for the celery worker, which can be done either in celeryconfig.py or on the
|
||||
command line with '--hostname'."""
|
||||
hostname = worker.hostname
|
||||
if hostname.startswith("primary"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id="public") as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
|
||||
valid_tenants = [
|
||||
tenant
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
@@ -31,10 +31,21 @@ if REDIS_SSL:
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||
|
||||
# region Broker settings
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
# Leaving this to the default of True may cause double logging since both our own app
|
||||
# and celery think they are controlling the logger.
|
||||
# TODO: Configure celery's logger entirely manually and set this to False
|
||||
# worker_hijack_root_logger = False
|
||||
|
||||
broker_connection_retry_on_startup = True
|
||||
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
||||
|
||||
@@ -49,7 +60,6 @@ broker_transport_options = {
|
||||
"socket_keepalive": True,
|
||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
}
|
||||
# endregion
|
||||
|
||||
# redis backend settings
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
@@ -63,19 +73,10 @@ redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# region Task result backend settings
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
# endregion
|
||||
|
||||
# Leaving this to the default of True may cause double logging since both our own app
|
||||
# and celery think they are controlling the logger.
|
||||
# TODO: Configure celery's logger entirely manually and set this to False
|
||||
# worker_hijack_root_logger = False
|
||||
|
||||
# region Notes on serialization performance
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
@@ -101,4 +102,3 @@ result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
# endregion
|
||||
@@ -1,14 +0,0 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
@@ -1,20 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,21 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,22 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY
|
||||
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
||||
@@ -1,20 +0,0 @@
|
||||
import danswer.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,45 +1,29 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
"""Raised to the caller to indicate dependent tasks are running that would interfere
|
||||
with connector deletion."""
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
@@ -52,44 +36,25 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
r.set(rcs.fence_key, cc_pair_id)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
r.delete(rcs.fence_key)
|
||||
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
@@ -98,87 +63,51 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
|
||||
Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
|
||||
still running. In our case, the caller reacts by setting a stop signal in Redis to
|
||||
exit those tasks as quickly as possible.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to load the state of the object inside the fence
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
fence_value = RedisConnectorDeletionFenceData(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
|
||||
try:
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
|
||||
if r.get(rci.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
rcp = RedisConnectorPruning(cc_pair_id)
|
||||
if r.get(rcp.fence_key):
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
except TaskDependencyError:
|
||||
r.delete(rcd.fence_key)
|
||||
raise
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
r.delete(rcd.fence_key)
|
||||
tasks_generated = rcd.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
else:
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
# set this only after all tasks have been added
|
||||
fence_value.num_tasks = tasks_generated
|
||||
r.set(rcd.fence_key, fence_value.model_dump_json())
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
@@ -5,24 +5,17 @@ from time import sleep
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -47,49 +40,18 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
def __init__(
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
return False
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
def check_for_indexing(*, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
@@ -102,54 +64,31 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
|
||||
return None
|
||||
else:
|
||||
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
if current_search_settings.provider_type is None and not MULTI_TENANT:
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
for search_settings_instance in search_settings:
|
||||
rci = RedisConnectorIndexing(
|
||||
cc_pair_id, search_settings_instance.id
|
||||
cc_pair.id, search_settings_instance.id
|
||||
)
|
||||
if r.exists(rci.fence_key):
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
@@ -165,7 +104,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
@@ -175,7 +113,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
if attempt_id:
|
||||
task_logger.info(
|
||||
f"Indexing queued: cc_pair={cc_pair.id} index_attempt={attempt_id}"
|
||||
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -183,7 +121,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -275,7 +213,6 @@ def _should_index(
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
@@ -311,10 +248,6 @@ def try_creating_indexing_task(
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
@@ -325,19 +258,7 @@ def try_creating_indexing_task(
|
||||
|
||||
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
|
||||
|
||||
# set a basic fence to start
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
# and cleans them up
|
||||
# therefore we must create the attempt and the task after the fence goes up
|
||||
# create the index attempt ... just for tracking purposes
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
@@ -358,20 +279,18 @@ def try_creating_indexing_task(
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
return None
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
fence_value.index_attempt_id = index_attempt_id
|
||||
fence_value.celery_task_id = result.id
|
||||
# set this only after all tasks have been added
|
||||
fence_value = RedisConnectorIndexingFenceData(
|
||||
index_attempt_id=index_attempt_id,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
r.set(rci.fence_key, fence_value.model_dump_json())
|
||||
except Exception:
|
||||
r.delete(rci.fence_key)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
task_logger.exception("Unexpected exception")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
@@ -453,55 +372,8 @@ def connector_indexing_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
if r.exists(rcd.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={rcd.fence_key}"
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
if r.exists(rcs.fence_key):
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={rcs.fence_key}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
while True:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rci.fence_key))
|
||||
if fence_value is None:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
|
||||
)
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
|
||||
)
|
||||
raise
|
||||
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock = r.lock(
|
||||
rci.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
@@ -511,20 +383,17 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
|
||||
)
|
||||
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
|
||||
return None
|
||||
|
||||
fence_data.started = datetime.now(timezone.utc)
|
||||
r.set(rci.fence_key, fence_data.model_dump_json())
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
f"Index attempt not found: index_attempt_id={index_attempt_id}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -533,31 +402,31 @@ def connector_indexing_task(
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise ValueError(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
f"Connector not found: connector_id={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise ValueError(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
f"Credential not found: credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
|
||||
|
||||
# define a callback class
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rci.generator_progress_key, lock, r
|
||||
)
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rci.generator_progress_key, amount)
|
||||
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
progress_callback=redis_increment_callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
@@ -570,10 +439,9 @@ def connector_indexing_task(
|
||||
|
||||
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
|
||||
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
|
||||
if attempt:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
|
||||
r.delete(rci.generator_lock_key)
|
||||
r.delete(rci.generator_progress_key)
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
@@ -3,19 +3,15 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisConnectorStop
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
@@ -27,7 +23,6 @@ from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
@@ -42,9 +37,8 @@ logger = setup_logger()
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_pruning(*, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
@@ -57,35 +51,26 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
lock_beat.reacquire()
|
||||
if not is_pruning_due(cc_pair, db_session, r):
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
|
||||
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -133,7 +118,6 @@ def is_pruning_due(
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
@@ -171,10 +155,6 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
@@ -201,7 +181,7 @@ def try_creating_prune_generator_task(
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcp.fence_key, 1)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
@@ -216,14 +196,9 @@ def try_creating_prune_generator_task(
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_pruning_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -241,7 +216,7 @@ def connector_pruning_generator_task(
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -259,22 +234,22 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
return
|
||||
|
||||
# Define the callback function
|
||||
def redis_increment_callback(amount: int) -> None:
|
||||
lock.reacquire()
|
||||
r.incrby(rcp.generator_progress_key, amount)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
rcs = RedisConnectorStop(cc_pair_id)
|
||||
|
||||
callback = RunIndexingCallback(
|
||||
rcs.fence_key, rcp.generator_progress_key, lock, r
|
||||
)
|
||||
# a list of docs in the source
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
runnable_connector, redis_increment_callback
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
@@ -292,7 +267,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"cc_pair_id={cc_pair.id} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)} "
|
||||
f"doc_source={cc_pair.connector.source}"
|
||||
)
|
||||
@@ -300,24 +275,22 @@ def connector_pruning_generator_task(
|
||||
rcp.documents_to_prune = set(doc_ids_to_remove)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair={cc_pair.id}"
|
||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcp.generate_tasks(
|
||||
self.app, db_session, r, None, tenant_id
|
||||
celery_app, db_session, r, None, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorPruning.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(rcp.generator_complete_key, tasks_generated)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
|
||||
)
|
||||
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
|
||||
|
||||
r.delete(rcp.generator_progress_key)
|
||||
r.delete(rcp.taskset_key)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorDeletionFenceData(BaseModel):
|
||||
num_tasks: int | None
|
||||
submitted: datetime
|
||||
@@ -1,10 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
@@ -1,40 +0,0 @@
|
||||
import httpx
|
||||
from tenacity import retry
|
||||
from tenacity import retry_if_exception_type
|
||||
from tenacity import stop_after_delay
|
||||
from tenacity import wait_random_exponential
|
||||
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
|
||||
|
||||
class RetryDocumentIndex:
|
||||
"""A wrapper class to help with specific retries against Vespa involving
|
||||
read timeouts.
|
||||
|
||||
wait_random_exponential implements full jitter as per this article:
|
||||
https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/"""
|
||||
|
||||
MAX_WAIT = 30
|
||||
|
||||
# STOP_AFTER + MAX_WAIT should be slightly less (5?) than the celery soft_time_limit
|
||||
STOP_AFTER = 70
|
||||
|
||||
def __init__(self, index: DocumentIndex):
|
||||
self.index: DocumentIndex = index
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
return self.index.delete_single(doc_id)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
|
||||
return self.index.update_single(doc_id, fields)
|
||||
@@ -1,19 +1,16 @@
|
||||
from http import HTTPStatus
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from tenacity import RetryError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import mark_document_as_modified
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
@@ -22,20 +19,20 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
index_attempt_id: int
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task,
|
||||
@@ -59,7 +56,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.info(f"tenant={tenant_id} doc={document_id}")
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -67,19 +64,17 @@ def document_by_cc_pair_cleanup_task(
|
||||
chunks_affected = 0
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = retry_index.delete_single(document_id)
|
||||
chunks_affected = document_index.delete_single(document_id)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
@@ -109,7 +104,9 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields=fields)
|
||||
chunks_affected = document_index.update_single(
|
||||
document_id, fields=fields
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
@@ -125,58 +122,23 @@ def document_by_cc_pair_cleanup_task(
|
||||
else:
|
||||
pass
|
||||
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"tenant_id={tenant_id} "
|
||||
f"document_id={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
db_session.commit()
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
|
||||
)
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
|
||||
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
|
||||
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
else:
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.info(
|
||||
f"Max retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id):
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
return False
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
@@ -4,9 +4,7 @@ from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
@@ -14,10 +12,10 @@ from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
@@ -25,15 +23,7 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
|
||||
RedisConnectorDeletionFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||
RedisConnectorIndexingFenceData,
|
||||
)
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
@@ -64,6 +54,7 @@ from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaDocumentFields
|
||||
@@ -85,9 +76,8 @@ logger = setup_logger()
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
@@ -104,71 +94,49 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
|
||||
|
||||
# region document set scan
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
|
||||
for document_set, _ in document_set_info:
|
||||
document_set_ids.append(document_set.id)
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
document_set, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
# endregion
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
# We shouldn't actually get here if the ee version check works
|
||||
pass
|
||||
else:
|
||||
usergroup_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
|
||||
for usergroup in user_groups:
|
||||
usergroup_ids.append(usergroup.id)
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
usergroup, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
# We shouldn't actually get here if the ee version check works
|
||||
pass
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
tenant_id: str | None,
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
@@ -219,8 +187,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
celery_app: Celery,
|
||||
document_set_id: int,
|
||||
document_set: DocumentSet,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
@@ -228,7 +195,7 @@ def try_generate_document_set_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
@@ -236,10 +203,7 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
document_set = get_document_set_by_id(db_session, document_set_id)
|
||||
if not document_set:
|
||||
return None
|
||||
|
||||
db_session.refresh(document_set)
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
@@ -274,8 +238,7 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
celery_app: Celery,
|
||||
usergroup_id: int,
|
||||
usergroup: UserGroup,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
@@ -283,21 +246,14 @@ def try_generate_user_group_sync_tasks(
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_group"
|
||||
)
|
||||
|
||||
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||
if not usergroup:
|
||||
return None
|
||||
|
||||
db_session.refresh(usergroup)
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
@@ -376,7 +332,7 @@ def monitor_document_set_taskset(
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"Document set sync progress: document_set={document_set_id} "
|
||||
f"Document set sync progress: document_set_id={document_set_id} "
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
@@ -391,12 +347,12 @@ def monitor_document_set_taskset(
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set: document_set={document_set_id}"
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
@@ -416,29 +372,19 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
# read related data and evaluate/print task progress
|
||||
fence_value = cast(bytes, r.get(rcd.fence_key))
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
fence_json = fence_value.decode("utf-8")
|
||||
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
|
||||
cast(str, fence_json)
|
||||
)
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
|
||||
)
|
||||
raise
|
||||
|
||||
# the fence is setting up but isn't ready yet
|
||||
if fence_data.num_tasks is None:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
@@ -501,7 +447,7 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
@@ -511,17 +457,17 @@ def monitor_connector_deletion_taskset(
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"Failed to run connector_deletion. "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"Successfully deleted cc_pair: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
f"docs_deleted={initial_count}"
|
||||
)
|
||||
|
||||
r.delete(rcd.taskset_key)
|
||||
@@ -631,12 +577,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
|
||||
)
|
||||
|
||||
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# Read result state BEFORE generator_complete_key to avoid a race condition
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
|
||||
result_state = result.state
|
||||
|
||||
@@ -738,9 +679,36 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
f"pruning={n_pruning}"
|
||||
)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
@@ -758,42 +726,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
if not r.exists(rci.fence_key):
|
||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
@@ -814,22 +748,22 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
@@ -851,43 +785,19 @@ def vespa_metadata_sync_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields)
|
||||
chunks_affected = document_index.update_single(document_id, fields=fields)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
|
||||
f"document_id={document_id} action=sync chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.beat", "celery_app"
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.heavy import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.indexing import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from danswer.background.celery.apps.light import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -42,19 +41,6 @@ logger = setup_logger()
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
class RunIndexingCallbackInterface(ABC):
|
||||
"""Defines a callback interface to be passed to
|
||||
to run_indexing_entrypoint."""
|
||||
|
||||
@abstractmethod
|
||||
def should_stop(self) -> bool:
|
||||
"""Signal to stop the looping function in flight."""
|
||||
|
||||
@abstractmethod
|
||||
def progress(self, amount: int) -> None:
|
||||
"""Send progress updates to the caller."""
|
||||
|
||||
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
@@ -106,7 +92,7 @@ def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
@@ -220,11 +206,6 @@ def _run_indexing(
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
if (
|
||||
(
|
||||
@@ -282,8 +263,8 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress(len(doc_batch))
|
||||
if progress_callback:
|
||||
progress_callback(len(doc_batch))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
@@ -413,7 +394,7 @@ def run_indexing_entrypoint(
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
if is_ee:
|
||||
@@ -436,7 +417,7 @@ def run_indexing_entrypoint(
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
_run_indexing(db_session, attempt, tenant_id, progress_callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
|
||||
494
backend/danswer/background/update.py
Executable file
494
backend/danswer/background/update.py
Executable file
@@ -0,0 +1,494 @@
|
||||
# TODO(rkuo): delete after background indexing via celery is fully vetted
|
||||
# import logging
|
||||
# import time
|
||||
# from datetime import datetime
|
||||
# import dask
|
||||
# from dask.distributed import Client
|
||||
# from dask.distributed import Future
|
||||
# from distributed import LocalCluster
|
||||
# from sqlalchemy import text
|
||||
# from sqlalchemy.exc import ProgrammingError
|
||||
# from sqlalchemy.orm import Session
|
||||
# from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
# from danswer.background.indexing.job_client import SimpleJob
|
||||
# from danswer.background.indexing.job_client import SimpleJobClient
|
||||
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
# from danswer.configs.app_configs import MULTI_TENANT
|
||||
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
# from danswer.configs.constants import DocumentSource
|
||||
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
# from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
# from danswer.db.connector import fetch_connectors
|
||||
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
# from danswer.db.engine import get_db_current_time
|
||||
# from danswer.db.engine import get_session_with_tenant
|
||||
# from danswer.db.engine import get_sqlalchemy_engine
|
||||
# from danswer.db.engine import SqlEngine
|
||||
# from danswer.db.index_attempt import create_index_attempt
|
||||
# from danswer.db.index_attempt import get_index_attempt
|
||||
# from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
# from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
# from danswer.db.index_attempt import mark_attempt_failed
|
||||
# from danswer.db.models import ConnectorCredentialPair
|
||||
# from danswer.db.models import IndexAttempt
|
||||
# from danswer.db.models import IndexingStatus
|
||||
# from danswer.db.models import IndexModelStatus
|
||||
# from danswer.db.models import SearchSettings
|
||||
# from danswer.db.search_settings import get_current_search_settings
|
||||
# from danswer.db.search_settings import get_secondary_search_settings
|
||||
# from danswer.db.swap_index import check_index_swap
|
||||
# from danswer.document_index.vespa.index import VespaIndex
|
||||
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
# from danswer.utils.logger import setup_logger
|
||||
# from danswer.utils.variable_functionality import global_version
|
||||
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
# from shared_configs.configs import LOG_LEVEL
|
||||
# logger = setup_logger()
|
||||
# # If the indexing dies, it's most likely due to resource constraints,
|
||||
# # restarting just delays the eventual failure, not useful to the user
|
||||
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
|
||||
# _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
# "Stopped mid run, likely due to the background process being killed"
|
||||
# )
|
||||
# def _should_create_new_indexing(
|
||||
# cc_pair: ConnectorCredentialPair,
|
||||
# last_index: IndexAttempt | None,
|
||||
# search_settings_instance: SearchSettings,
|
||||
# secondary_index_building: bool,
|
||||
# db_session: Session,
|
||||
# ) -> bool:
|
||||
# connector = cc_pair.connector
|
||||
# # don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
# if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
# return False
|
||||
# # User can still manually create single indexing attempts via the UI for the
|
||||
# # currently in use index
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# if (
|
||||
# search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
# and secondary_index_building
|
||||
# ):
|
||||
# return False
|
||||
# # When switching over models, always index at least once
|
||||
# if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
# if last_index:
|
||||
# # No new index if the last index attempt succeeded
|
||||
# # Once is enough. The model will never be able to swap otherwise.
|
||||
# if last_index.status == IndexingStatus.SUCCESS:
|
||||
# return False
|
||||
# # No new index if the last index attempt is waiting to start
|
||||
# if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
# return False
|
||||
# # No new index if the last index attempt is running
|
||||
# if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
# return False
|
||||
# else:
|
||||
# if (
|
||||
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
# ): # Ingestion API
|
||||
# return False
|
||||
# return True
|
||||
# # If the connector is paused or is the ingestion API, don't index
|
||||
# # NOTE: during an embedding model switch over, the following logic
|
||||
# # is bypassed by the above check for a future model
|
||||
# if (
|
||||
# not cc_pair.status.is_active()
|
||||
# or connector.id == 0
|
||||
# or connector.source == DocumentSource.INGESTION_API
|
||||
# ):
|
||||
# return False
|
||||
# if not last_index:
|
||||
# return True
|
||||
# if connector.refresh_freq is None:
|
||||
# return False
|
||||
# # Only one scheduled/ongoing job per connector at a time
|
||||
# # this prevents cases where
|
||||
# # (1) the "latest" index_attempt is scheduled so we show
|
||||
# # that in the UI despite another index_attempt being in-progress
|
||||
# # (2) multiple scheduled index_attempts at a time
|
||||
# if (
|
||||
# last_index.status == IndexingStatus.NOT_STARTED
|
||||
# or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
# ):
|
||||
# return False
|
||||
# current_db_time = get_db_current_time(db_session)
|
||||
# time_since_index = current_db_time - last_index.time_updated
|
||||
# return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
# def _mark_run_failed(
|
||||
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
# ) -> None:
|
||||
# """Marks the `index_attempt` row as failed + updates the `
|
||||
# connector_credential_pair` to reflect that the run failed"""
|
||||
# logger.warning(
|
||||
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
# )
|
||||
# mark_attempt_failed(
|
||||
# index_attempt=index_attempt,
|
||||
# db_session=db_session,
|
||||
# failure_reason=failure_reason,
|
||||
# )
|
||||
# """Main funcs"""
|
||||
# def create_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
|
||||
# ) -> None:
|
||||
# """Creates new indexing jobs for each connector / credential pair which is:
|
||||
# 1. Enabled
|
||||
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
# 3. There is not already an ongoing indexing attempt for this pair
|
||||
# """
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# ongoing: set[tuple[int | None, int]] = set()
|
||||
# for attempt_id in existing_jobs:
|
||||
# attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# if attempt is None:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
# "indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# ongoing.add(
|
||||
# (
|
||||
# attempt.connector_credential_pair_id,
|
||||
# attempt.search_settings_id,
|
||||
# )
|
||||
# )
|
||||
# # Get the primary search settings
|
||||
# primary_search_settings = get_current_search_settings(db_session)
|
||||
# search_settings = [primary_search_settings]
|
||||
# # Check for secondary search settings
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if secondary_search_settings is not None:
|
||||
# # If secondary settings exist, add them to the list
|
||||
# search_settings.append(secondary_search_settings)
|
||||
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
# for cc_pair in all_connector_credential_pairs:
|
||||
# for search_settings_instance in search_settings:
|
||||
# # Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
# if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
# continue
|
||||
# last_attempt = get_last_attempt_for_cc_pair(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# if not _should_create_new_indexing(
|
||||
# cc_pair=cc_pair,
|
||||
# last_index=last_attempt,
|
||||
# search_settings_instance=search_settings_instance,
|
||||
# secondary_index_building=len(search_settings) > 1,
|
||||
# db_session=db_session,
|
||||
# ):
|
||||
# continue
|
||||
# create_index_attempt(
|
||||
# cc_pair.id, search_settings_instance.id, db_session
|
||||
# )
|
||||
# def cleanup_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# tenant_id: str | None,
|
||||
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# # clean up completed jobs
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# for attempt_id, job in existing_jobs.items():
|
||||
# index_attempt = get_index_attempt(
|
||||
# db_session=db_session, index_attempt_id=attempt_id
|
||||
# )
|
||||
# # do nothing for ongoing jobs that haven't been stopped
|
||||
# if not job.done():
|
||||
# if not index_attempt:
|
||||
# continue
|
||||
# if not index_attempt.is_finished():
|
||||
# continue
|
||||
# if job.status == "error":
|
||||
# logger.error(job.exception())
|
||||
# job.release()
|
||||
# del existing_jobs_copy[attempt_id]
|
||||
# if not index_attempt:
|
||||
# logger.error(
|
||||
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
# "up indexing jobs"
|
||||
# )
|
||||
# continue
|
||||
# if (
|
||||
# index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
# or job.status == "error"
|
||||
# ):
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# # clean up in-progress jobs that were never completed
|
||||
# try:
|
||||
# connectors = fetch_connectors(db_session)
|
||||
# for connector in connectors:
|
||||
# in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
# connector.id, db_session
|
||||
# )
|
||||
# for index_attempt in in_progress_indexing_attempts:
|
||||
# if index_attempt.id in existing_jobs:
|
||||
# # If index attempt is canceled, stop the run
|
||||
# if index_attempt.status == IndexingStatus.FAILED:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# # check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# # on the fact that the `time_updated` field is constantly updated every
|
||||
# # batch of documents indexed
|
||||
# current_db_time = get_db_current_time(db_session=db_session)
|
||||
# time_since_update = current_db_time - index_attempt.time_updated
|
||||
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
# existing_jobs[index_attempt.id].cancel()
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
# "The run will be re-attempted at next scheduled indexing time.",
|
||||
# )
|
||||
# else:
|
||||
# # If job isn't known, simply mark it as failed
|
||||
# _mark_run_failed(
|
||||
# db_session=db_session,
|
||||
# index_attempt=index_attempt,
|
||||
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
# )
|
||||
# except ProgrammingError:
|
||||
# logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
# return existing_jobs_copy
|
||||
# def kickoff_indexing_jobs(
|
||||
# existing_jobs: dict[int, Future | SimpleJob],
|
||||
# client: Client | SimpleJobClient,
|
||||
# secondary_client: Client | SimpleJobClient,
|
||||
# tenant_id: str | None,
|
||||
# ) -> dict[int, Future | SimpleJob]:
|
||||
# existing_jobs_copy = existing_jobs.copy()
|
||||
# current_session = get_session_with_tenant(tenant_id)
|
||||
# # Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
# with current_session as db_session:
|
||||
# # get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# # we must process attempts in a FIFO manner to prevent connector starvation
|
||||
# new_indexing_attempts = [
|
||||
# (attempt, attempt.search_settings)
|
||||
# for attempt in get_not_started_index_attempts(db_session)
|
||||
# if attempt.id not in existing_jobs
|
||||
# ]
|
||||
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
# if not new_indexing_attempts:
|
||||
# return existing_jobs
|
||||
# indexing_attempt_count = 0
|
||||
# primary_client_full = False
|
||||
# secondary_client_full = False
|
||||
# for attempt, search_settings in new_indexing_attempts:
|
||||
# if primary_client_full and secondary_client_full:
|
||||
# break
|
||||
# use_secondary_index = (
|
||||
# search_settings.status == IndexModelStatus.FUTURE
|
||||
# if search_settings is not None
|
||||
# else False
|
||||
# )
|
||||
# if attempt.connector_credential_pair.connector is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Connector is null"
|
||||
# )
|
||||
# continue
|
||||
# if attempt.connector_credential_pair.credential is None:
|
||||
# logger.warning(
|
||||
# f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
# )
|
||||
# with current_session as db_session:
|
||||
# mark_attempt_failed(
|
||||
# attempt, db_session, failure_reason="Credential is null"
|
||||
# )
|
||||
# continue
|
||||
# if not use_secondary_index:
|
||||
# if not primary_client_full:
|
||||
# run = client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# primary_client_full = True
|
||||
# else:
|
||||
# if not secondary_client_full:
|
||||
# run = secondary_client.submit(
|
||||
# run_indexing_entrypoint,
|
||||
# attempt.id,
|
||||
# tenant_id,
|
||||
# attempt.connector_credential_pair_id,
|
||||
# global_version.is_ee_version(),
|
||||
# pure=False,
|
||||
# )
|
||||
# if not run:
|
||||
# secondary_client_full = True
|
||||
# if run:
|
||||
# if indexing_attempt_count == 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
# )
|
||||
# indexing_attempt_count += 1
|
||||
# secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
# logger.info(
|
||||
# f"Indexing dispatched{secondary_str}: "
|
||||
# f"attempt_id={attempt.id} "
|
||||
# f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
# )
|
||||
# existing_jobs_copy[attempt.id] = run
|
||||
# if indexing_attempt_count > 0:
|
||||
# logger.info(
|
||||
# f"Indexing dispatch results: "
|
||||
# f"initial_pending={len(new_indexing_attempts)} "
|
||||
# f"started={indexing_attempt_count} "
|
||||
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
# )
|
||||
# return existing_jobs_copy
|
||||
# def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
# if not MULTI_TENANT:
|
||||
# return [None]
|
||||
# with get_session_with_tenant(tenant_id="public") as session:
|
||||
# result = session.execute(
|
||||
# text(
|
||||
# """
|
||||
# SELECT schema_name
|
||||
# FROM information_schema.schemata
|
||||
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
# )
|
||||
# )
|
||||
# tenant_ids = [row[0] for row in result]
|
||||
# valid_tenants = [
|
||||
# tenant
|
||||
# for tenant in tenant_ids
|
||||
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
# ]
|
||||
# return valid_tenants
|
||||
# def update_loop(
|
||||
# delay: int = 10,
|
||||
# num_workers: int = NUM_INDEXING_WORKERS,
|
||||
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
# ) -> None:
|
||||
# if not MULTI_TENANT:
|
||||
# # We can use this function as we are certain only the public schema exists
|
||||
# # (explicitly for the non-`MULTI_TENANT` case)
|
||||
# engine = get_sqlalchemy_engine()
|
||||
# with Session(engine) as db_session:
|
||||
# check_index_swap(db_session=db_session)
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# # So that the first time users aren't surprised by really slow speed of first
|
||||
# # batch of documents indexed
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice("Running a first inference to warm up embedding model")
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(
|
||||
# embedding_model=embedding_model,
|
||||
# )
|
||||
# logger.notice("First inference complete.")
|
||||
# client_primary: Client | SimpleJobClient
|
||||
# client_secondary: Client | SimpleJobClient
|
||||
# if DASK_JOB_CLIENT_ENABLED:
|
||||
# cluster_primary = LocalCluster(
|
||||
# n_workers=num_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# cluster_secondary = LocalCluster(
|
||||
# n_workers=num_secondary_workers,
|
||||
# threads_per_worker=1,
|
||||
# silence_logs=logging.ERROR,
|
||||
# )
|
||||
# client_primary = Client(cluster_primary)
|
||||
# client_secondary = Client(cluster_secondary)
|
||||
# if LOG_LEVEL.lower() == "debug":
|
||||
# client_primary.register_worker_plugin(ResourceLogger())
|
||||
# else:
|
||||
# client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
|
||||
# logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
# while True:
|
||||
# start = time.time()
|
||||
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
# if existing_jobs:
|
||||
# logger.debug(
|
||||
# "Found existing indexing jobs: "
|
||||
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
|
||||
# )
|
||||
# try:
|
||||
# tenants = get_all_tenant_ids()
|
||||
# for tenant_id in tenants:
|
||||
# try:
|
||||
# logger.debug(
|
||||
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
|
||||
# )
|
||||
# with get_session_with_tenant(tenant_id) as db_session:
|
||||
# index_to_expire = check_index_swap(db_session=db_session)
|
||||
# if index_to_expire and tenant_id and MULTI_TENANT:
|
||||
# VespaIndex.delete_entries_by_tenant_id(
|
||||
# tenant_id=tenant_id,
|
||||
# index_name=index_to_expire.index_name,
|
||||
# )
|
||||
# if not MULTI_TENANT:
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# if search_settings.provider_type is None:
|
||||
# logger.notice(
|
||||
# "Running a first inference to warm up embedding model"
|
||||
# )
|
||||
# embedding_model = EmbeddingModel.from_db_model(
|
||||
# search_settings=search_settings,
|
||||
# server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
# server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
# )
|
||||
# warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
# logger.notice("First inference complete.")
|
||||
# tenant_jobs = existing_jobs.get(tenant_id, {})
|
||||
# tenant_jobs = cleanup_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs, tenant_id=tenant_id
|
||||
# )
|
||||
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
|
||||
# tenant_jobs = kickoff_indexing_jobs(
|
||||
# existing_jobs=tenant_jobs,
|
||||
# client=client_primary,
|
||||
# secondary_client=client_secondary,
|
||||
# tenant_id=tenant_id,
|
||||
# )
|
||||
# existing_jobs[tenant_id] = tenant_jobs
|
||||
# except Exception as e:
|
||||
# logger.exception(
|
||||
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Failed to run update due to {e}")
|
||||
# sleep_time = delay - (time.time() - start)
|
||||
# if sleep_time > 0:
|
||||
# time.sleep(sleep_time)
|
||||
# def update__main() -> None:
|
||||
# set_is_ee_based_on_env_variable()
|
||||
# # initialize the Postgres connection pool
|
||||
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
|
||||
# logger.notice("Starting indexing service")
|
||||
# update_loop()
|
||||
# if __name__ == "__main__":
|
||||
# update__main()
|
||||
@@ -41,19 +41,6 @@ personas:
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "General Information"
|
||||
description: "Ask about available information"
|
||||
message: "Hello! I'm interested in learning more about the information available here. Could you give me an overview of the types of data or documents that might be accessible?"
|
||||
- name: "Specific Topic Search"
|
||||
description: "Search for specific information"
|
||||
message: "Hi! I'd like to learn more about a specific topic. Could you help me find relevant documents and information?"
|
||||
- name: "Recent Updates"
|
||||
description: "Inquire about latest additions"
|
||||
message: "Hello! I'm curious about any recent updates or additions to the knowledge base. Can you tell me what new information has been added lately?"
|
||||
- name: "Cross-referencing Information"
|
||||
description: "Connect information from different sources"
|
||||
message: "Hi! I'm working on a project that requires connecting information from multiple sources. How can I effectively cross-reference data across different documents or categories?"
|
||||
|
||||
- id: 1
|
||||
name: "General"
|
||||
@@ -70,19 +57,6 @@ personas:
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Open Discussion"
|
||||
description: "Start an open-ended conversation"
|
||||
message: "Hi! Can you help me write a professional email?"
|
||||
- name: "Problem Solving"
|
||||
description: "Get help with a challenge"
|
||||
message: "Hello! I need help managing my daily tasks better. Do you have any simple tips?"
|
||||
- name: "Learn Something New"
|
||||
description: "Explore a new topic"
|
||||
message: "Hi! Could you explain what project management is in simple terms?"
|
||||
- name: "Creative Brainstorming"
|
||||
description: "Generate creative ideas"
|
||||
message: "Hello! I need to brainstorm some team building activities. Do you have any fun suggestions?"
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -99,19 +73,7 @@ personas:
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
starter_messages:
|
||||
- name: "Document Search"
|
||||
description: "Find exact information"
|
||||
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
|
||||
- name: "Process Verification"
|
||||
description: "Find exact quotes"
|
||||
message: "Hello! I need to understand our project approval process. Could you find the exact steps from our documentation?"
|
||||
- name: "Technical Documentation"
|
||||
description: "Search technical details"
|
||||
message: "Hi there! I'm looking for information about our deployment procedures. Can you find the specific steps from our technical guides?"
|
||||
- name: "Policy Reference"
|
||||
description: "Check official policies"
|
||||
message: "Hello! Could you help me find our official guidelines about client communication? I need the exact wording from our documentation."
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
@@ -124,21 +86,8 @@ personas:
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
starter_messages:
|
||||
- name: "Landscape"
|
||||
description: "Generate a landscape image"
|
||||
message: "Create an image of a serene mountain lake at sunset, with snow-capped peaks reflected in the calm water and a small wooden cabin on the shore."
|
||||
- name: "Character"
|
||||
description: "Generate a character image"
|
||||
message: "Generate an image of a futuristic robot with glowing blue eyes, sleek metallic body, and intricate circuitry visible through transparent panels on its chest and arms."
|
||||
- name: "Abstract"
|
||||
description: "Create an abstract image"
|
||||
message: "Create an abstract image representing the concept of time, using swirling clock hands, fragmented hourglasses, and streaks of light to convey the passage of moments and eras."
|
||||
- name: "Urban Scene"
|
||||
description: "Generate an urban landscape"
|
||||
message: "Generate an image of a bustling futuristic cityscape at night, with towering skyscrapers, flying vehicles, holographic advertisements, and a mix of neon and bioluminescent lighting."
|
||||
|
||||
@@ -672,7 +672,6 @@ def stream_chat_message_objects(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
document_pruning_config=document_pruning_config,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
),
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
|
||||
@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
@@ -134,6 +131,7 @@ try:
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
@@ -142,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
@@ -200,41 +198,6 @@ try:
|
||||
except ValueError:
|
||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
||||
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
|
||||
try:
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
|
||||
try:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_INDEXING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -290,6 +253,12 @@ CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save pages labels as Danswer metadata tags
|
||||
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
|
||||
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
@@ -437,7 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "5")
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
@@ -461,12 +430,20 @@ AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
|
||||
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
|
||||
|
||||
|
||||
# Cloud configuration
|
||||
|
||||
# Multi-tenancy configuration
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# Security and authentication
|
||||
SECRET_JWT_KEY = os.environ.get(
|
||||
"SECRET_JWT_KEY", ""
|
||||
) # Used for encryption of the JWT token for user's tenant context
|
||||
DATA_PLANE_SECRET = os.environ.get(
|
||||
"DATA_PLANE_SECRET", ""
|
||||
) # Used for secure communication between the control and data plane
|
||||
|
||||
@@ -31,6 +31,9 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -43,6 +46,7 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
@@ -67,7 +71,6 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
@@ -13,8 +13,8 @@ Connectors come in 3 different flows:
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
changes additions and changes since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which generally be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -23,7 +24,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -44,6 +44,8 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
params = {
|
||||
"count": str(batch_size),
|
||||
"offset": str(start_ind),
|
||||
@@ -61,7 +63,8 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
batch = bookstack_client.get(endpoint, params=params).get("data", [])
|
||||
doc_batch = [transformer(bookstack_client, item) for item in batch]
|
||||
for item in batch:
|
||||
doc_batch.append(transformer(bookstack_client, item))
|
||||
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -18,7 +19,6 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"
|
||||
@@ -210,7 +210,6 @@ if __name__ == "__main__":
|
||||
"clickup_team_id": os.environ["clickup_team_id"],
|
||||
}
|
||||
)
|
||||
|
||||
latest_docs = clickup_connector.load_from_state()
|
||||
|
||||
for doc in latest_docs:
|
||||
|
||||
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
32
backend/danswer/connectors/confluence/confluence_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import bs4
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def get_used_attachments(text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachment in used
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
files_in_used = []
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
files_in_used.append(attachment.attrs["ri:filename"])
|
||||
return files_in_used
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,226 +0,0 @@
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
wrap it with handle_confluence_rate_limit.
|
||||
"""
|
||||
for attr_name in dir(self):
|
||||
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
setattr(
|
||||
self,
|
||||
attr_name,
|
||||
handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
)
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
"""
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
connection_char = "&" if "?" in url_suffix else "?"
|
||||
url_suffix += f"{connection_char}limit={limit}"
|
||||
|
||||
while url_suffix:
|
||||
try:
|
||||
next_response = self.get(url_suffix)
|
||||
except Exception as e:
|
||||
logger.exception("Error in danswer_cql: \n")
|
||||
raise e
|
||||
yield next_response.get("results", [])
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
return self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
group_name = quote(group_name)
|
||||
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/search/user?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def paginated_cql_page_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/content/search?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def cql_paginate_all_expansions(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
"""
|
||||
This function will paginate through the top level query first, then
|
||||
paginate through all of the expansions.
|
||||
The limit only applies to the top level query.
|
||||
All expansion paginations use default pagination limit (defined by Atlassian).
|
||||
"""
|
||||
|
||||
def _traverse_and_update(data: dict | list) -> None:
|
||||
if isinstance(data, dict):
|
||||
next_url = data.get("_links", {}).get("next")
|
||||
if next_url and "results" in data:
|
||||
data["results"].extend(self._paginate_url(next_url))
|
||||
|
||||
for value in data.values():
|
||||
_traverse_and_update(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
_traverse_and_update(item)
|
||||
|
||||
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(results)
|
||||
yield results
|
||||
219
backend/danswer/connectors/confluence/rate_limit_handler.py
Normal file
219
backend/danswer/connectors/confluence/rate_limit_handler.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# commenting out while we try using confluence's rate limiter instead
|
||||
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
# max_retries = 5
|
||||
# starting_delay = 5
|
||||
# backoff = 2
|
||||
|
||||
# # max_delay is used when the server doesn't hand back "Retry-After"
|
||||
# # and we have to decide the retry delay ourselves
|
||||
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
|
||||
|
||||
# # max_retry_after is used when we do get a "Retry-After" header
|
||||
# max_retry_after = 300 # should we really cap the maximum retry delay?
|
||||
|
||||
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
|
||||
|
||||
# # for testing purposes, rate limiting is written to fall back to a simpler
|
||||
# # rate limiting approach when redis is not available
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# for attempt in range(max_retries):
|
||||
# try:
|
||||
# # if multiple connectors are waiting for the next attempt, there could be an issue
|
||||
# # where many connectors are "released" onto the server at the same time.
|
||||
# # That's not ideal ... but coming up with a mechanism for queueing
|
||||
# # all of these connectors is a bigger problem that we want to take on
|
||||
# # right now
|
||||
# try:
|
||||
# next_attempt = r.get(NEXT_RETRY_KEY)
|
||||
# if next_attempt is None:
|
||||
# next_attempt = 0
|
||||
# else:
|
||||
# next_attempt = int(cast(int, next_attempt))
|
||||
|
||||
# # TODO: all connectors need to be interruptible moving forward
|
||||
# while time.monotonic() < next_attempt:
|
||||
# time.sleep(1)
|
||||
# except ConnectionError:
|
||||
# pass
|
||||
|
||||
# return confluence_call(*args, **kwargs)
|
||||
# except HTTPError as e:
|
||||
# # Check if the response or headers are None to avoid potential AttributeError
|
||||
# if e.response is None or e.response.headers is None:
|
||||
# logger.warning("HTTPError with `None` as response or as headers")
|
||||
# raise e
|
||||
|
||||
# retry_after_header = e.response.headers.get("Retry-After")
|
||||
# if (
|
||||
# e.response.status_code == 429
|
||||
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
# ):
|
||||
# retry_after = None
|
||||
# if retry_after_header is not None:
|
||||
# try:
|
||||
# retry_after = int(retry_after_header)
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
# if retry_after is not None:
|
||||
# if retry_after > max_retry_after:
|
||||
# logger.warning(
|
||||
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
|
||||
# )
|
||||
# retry_after = max_delay
|
||||
|
||||
# logger.warning(
|
||||
# f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
# )
|
||||
# try:
|
||||
# r.set(
|
||||
# NEXT_RETRY_KEY,
|
||||
# math.ceil(time.monotonic() + retry_after),
|
||||
# )
|
||||
# except ConnectionError:
|
||||
# pass
|
||||
# else:
|
||||
# logger.warning(
|
||||
# "Rate limit hit. Retrying with exponential backoff..."
|
||||
# )
|
||||
# delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
# delay_until = math.ceil(time.monotonic() + delay)
|
||||
|
||||
# try:
|
||||
# r.set(NEXT_RETRY_KEY, delay_until)
|
||||
# except ConnectionError:
|
||||
# while time.monotonic() < delay_until:
|
||||
# time.sleep(1)
|
||||
# else:
|
||||
# # re-raise, let caller handle
|
||||
# raise
|
||||
# except AttributeError as e:
|
||||
# # Some error within the Confluence library, unclear why it fails.
|
||||
# # Users reported it to be intermittent, so just retry
|
||||
# logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
# delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
# delay_until = math.ceil(time.monotonic() + delay)
|
||||
# try:
|
||||
# r.set(NEXT_RETRY_KEY, delay_until)
|
||||
# except ConnectionError:
|
||||
# while time.monotonic() < delay_until:
|
||||
# time.sleep(1)
|
||||
|
||||
# if attempt == max_retries - 1:
|
||||
# raise e
|
||||
|
||||
# return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 3600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
@@ -1,214 +0,0 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import bs4
|
||||
|
||||
from danswer.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from danswer.connectors.confluence.onyx_confluence import (
|
||||
OnyxConfluence,
|
||||
)
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.html_utils import format_document_soup
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
Args:
|
||||
base_url (str): The base url of the Confluence instance
|
||||
content_url (str): The url of the page or attachment download url
|
||||
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachments in use
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
|
||||
Returns:
|
||||
list[str]: List of filenames currently in use by the page text
|
||||
"""
|
||||
referenced_attachment_filenames = []
|
||||
soup = bs4.BeautifulSoup(page_text, "html.parser")
|
||||
for attachment in soup.findAll("ri:attachment"):
|
||||
referenced_attachment_filenames.append(attachment.attrs["ri:filename"])
|
||||
return referenced_attachment_filenames
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
if datetime_object.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def build_confluence_client(
|
||||
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
|
||||
) -> OnyxConfluence:
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials_json["confluence_username"] if is_cloud else None,
|
||||
password=credentials_json["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials_json["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=60,
|
||||
max_backoff_seconds=60,
|
||||
)
|
||||
@@ -11,10 +11,6 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.text_processing import is_valid_email
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
def datetime_to_utc(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
@@ -53,6 +49,10 @@ def get_experts_stores_representations(
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
def process_in_batches(
|
||||
objects: list[T], process_function: Callable[[T], U], batch_size: int
|
||||
) -> Iterator[list[U]]:
|
||||
|
||||
@@ -22,18 +22,18 @@ def retry_builder(
|
||||
jitter: tuple[float, float] | float = 1,
|
||||
) -> Callable[[F], F]:
|
||||
"""Builds a generic wrapper/decorator for calls to external APIs that
|
||||
may fail due to rate limiting, flakes, or other reasons. Applies exponential
|
||||
may fail due to rate limiting, flakes, or other reasons. Applies expontential
|
||||
backoff with jitter to retry the call."""
|
||||
|
||||
@retry(
|
||||
tries=tries,
|
||||
delay=delay,
|
||||
max_delay=max_delay,
|
||||
backoff=backoff,
|
||||
jitter=jitter,
|
||||
logger=cast(Logger, logger),
|
||||
)
|
||||
def retry_with_default(func: F) -> F:
|
||||
@retry(
|
||||
tries=tries,
|
||||
delay=delay,
|
||||
max_delay=max_delay,
|
||||
backoff=backoff,
|
||||
jitter=jitter,
|
||||
logger=cast(Logger, logger),
|
||||
)
|
||||
def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -14,6 +14,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -23,7 +24,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.document360.utils import flatten_child_categories
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -21,7 +22,6 @@ from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
# Limitations and Potential Improvements
|
||||
# 1. The "Categories themselves contain potentially relevant information" but they're not pulled in
|
||||
|
||||
@@ -34,6 +34,7 @@ from danswer.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.notion.connector import NotionConnector
|
||||
from danswer.connectors.productboard.connector import ProductboardConnector
|
||||
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
|
||||
from danswer.connectors.salesforce.connector import SalesforceConnector
|
||||
from danswer.connectors.sharepoint.connector import SharepointConnector
|
||||
from danswer.connectors.slab.connector import SlabConnector
|
||||
@@ -63,7 +64,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
@@ -76,6 +77,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLAB: SlabConnector,
|
||||
DocumentSource.NOTION: NotionConnector,
|
||||
DocumentSource.ZULIP: ZulipConnector,
|
||||
DocumentSource.REQUESTTRACKER: RequestTrackerConnector,
|
||||
DocumentSource.GURU: GuruConnector,
|
||||
DocumentSource.LINEAR: LinearConnector,
|
||||
DocumentSource.HUBSPOT: HubSpotConnector,
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -27,8 +28,7 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -175,7 +175,7 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
token = current_tenant_id.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
@@ -199,7 +199,7 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -24,9 +24,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# List of directories/Files to exclude
|
||||
exclude_patterns = [
|
||||
"logs",
|
||||
@@ -34,6 +31,7 @@ exclude_patterns = [
|
||||
".gitlab/",
|
||||
".pre-commit-config.yaml",
|
||||
]
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _batch_gitlab_objects(
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
@@ -39,7 +40,6 @@ from danswer.file_processing.unstructured import get_unstructured_api_key
|
||||
from danswer.file_processing.unstructured import unstructured_to_text
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -19,14 +19,13 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
# 1. Support fetching per collection via collection token (configured at connector creation)
|
||||
|
||||
GURU_API_BASE = "https://api.getguru.com/api/v1/"
|
||||
GURU_QUERY_ENDPOINT = GURU_API_BASE + "search/query"
|
||||
GURU_CARDS_URL = "https://app.getguru.com/card/"
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def unixtime_to_guru_time_str(unix_time: SecondsSinceUnixEpoch) -> str:
|
||||
|
||||
@@ -3,13 +3,11 @@ from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import SlimDocument
|
||||
|
||||
|
||||
SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
@@ -54,9 +52,9 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SlimConnector(BaseConnector):
|
||||
class IdConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NUM_RETRIES = 5
|
||||
|
||||
@@ -161,7 +161,7 @@ class LoopioConnector(LoadConnector, PollConnector):
|
||||
]
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=str(entry["id"]),
|
||||
id=entry["id"],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.LOOPIO,
|
||||
semantic_identifier=questions[0],
|
||||
|
||||
@@ -22,7 +22,6 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -231,7 +230,5 @@ if __name__ == "__main__":
|
||||
print("All docs", all_docs)
|
||||
current = datetime.datetime.now().timestamp()
|
||||
one_day_ago = current - 30 * 24 * 60 * 60 # 30 days
|
||||
|
||||
latest_docs = list(test_connector.poll_source(one_day_ago, current))
|
||||
|
||||
print("Latest docs", latest_docs)
|
||||
|
||||
@@ -14,7 +14,7 @@ class InputType(str, Enum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
SLIM_RETRIEVAL = "slim_retrieval"
|
||||
PRUNE = "prune"
|
||||
|
||||
|
||||
class ConnectorMissingCredentialError(PermissionError):
|
||||
@@ -169,11 +169,6 @@ class Document(DocumentBase):
|
||||
)
|
||||
|
||||
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
perm_sync_data: Any | None = None
|
||||
|
||||
|
||||
class DocumentErrorSummary(BaseModel):
|
||||
id: str
|
||||
semantic_id: str
|
||||
|
||||
@@ -134,14 +134,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
f"This is likely due to the block not being shared "
|
||||
f"with the Danswer integration. Exact exception:\n\n{e}"
|
||||
)
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error fetching blocks with status code {res.status_code}: {res.json()}"
|
||||
)
|
||||
|
||||
# This can occasionally happen, the reason is unknown and cannot be reproduced on our internal Notion
|
||||
# Assuming this will not be a critical loss of data
|
||||
return None
|
||||
return None
|
||||
logger.exception(f"Error fetching blocks - {res.json()}")
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
@@ -246,29 +241,24 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
# TODO there may be more types to handle here
|
||||
if isinstance(inner_dict, str):
|
||||
# For some objects the innermost value could just be a string, not sure what causes this
|
||||
return inner_dict
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
elif isinstance(inner_dict, dict):
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
return start
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
if "id" in inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
if "id" in inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
|
||||
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
|
||||
return None
|
||||
@@ -278,13 +268,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
if not prop:
|
||||
continue
|
||||
|
||||
try:
|
||||
inner_value = _recurse_properties(prop)
|
||||
except Exception as e:
|
||||
# This is not a critical failure, these properties are not the actual contents of the page
|
||||
# more similar to metadata
|
||||
logger.warning(f"Error recursing properties for {prop_name}: {e}")
|
||||
continue
|
||||
inner_value = _recurse_properties(prop)
|
||||
# Not a perfect way to format Notion database tables but there's no perfect representation
|
||||
# since this must be represented as plaintext
|
||||
if inner_value:
|
||||
|
||||
@@ -1,124 +1,153 @@
|
||||
# from datetime import datetime
|
||||
# from datetime import timezone
|
||||
# from logging import DEBUG as LOG_LVL_DEBUG
|
||||
# from typing import Any
|
||||
# from typing import List
|
||||
# from typing import Optional
|
||||
# from rt.rest1 import ALL_QUEUES
|
||||
# from rt.rest1 import Rt
|
||||
# from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
# from danswer.configs.constants import DocumentSource
|
||||
# from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
# from danswer.connectors.interfaces import PollConnector
|
||||
# from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
# from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
# from danswer.connectors.models import Document
|
||||
# from danswer.connectors.models import Section
|
||||
# from danswer.utils.logger import setup_logger
|
||||
# logger = setup_logger()
|
||||
# class RequestTrackerError(Exception):
|
||||
# pass
|
||||
# class RequestTrackerConnector(PollConnector):
|
||||
# def __init__(
|
||||
# self,
|
||||
# batch_size: int = INDEX_BATCH_SIZE,
|
||||
# ) -> None:
|
||||
# self.batch_size = batch_size
|
||||
# def txn_link(self, tid: int, txn: int) -> str:
|
||||
# return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}"
|
||||
# def build_doc_sections_from_txn(
|
||||
# self, connection: Rt, ticket_id: int
|
||||
# ) -> List[Section]:
|
||||
# Sections: List[Section] = []
|
||||
# get_history_resp = connection.get_history(ticket_id)
|
||||
# if get_history_resp is None:
|
||||
# raise RequestTrackerError(f"Ticket {ticket_id} cannot be found")
|
||||
# for tx in get_history_resp:
|
||||
# Sections.append(
|
||||
# Section(
|
||||
# link=self.txn_link(ticket_id, int(tx["id"])),
|
||||
# text="\n".join(
|
||||
# [
|
||||
# f"{k}:\n{v}\n" if k != "Attachments" else ""
|
||||
# for (k, v) in tx.items()
|
||||
# ]
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
# return Sections
|
||||
# def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
# self.rt_username = credentials.get("requesttracker_username")
|
||||
# self.rt_password = credentials.get("requesttracker_password")
|
||||
# self.rt_base_url = credentials.get("requesttracker_base_url")
|
||||
# return None
|
||||
# # This does not include RT file attachments yet.
|
||||
# def _process_tickets(
|
||||
# self, start: datetime, end: datetime
|
||||
# ) -> GenerateDocumentsOutput:
|
||||
# if any([self.rt_username, self.rt_password, self.rt_base_url]) is None:
|
||||
# raise ConnectorMissingCredentialError("requesttracker")
|
||||
# Rt0 = Rt(
|
||||
# f"{self.rt_base_url}/REST/1.0/",
|
||||
# self.rt_username,
|
||||
# self.rt_password,
|
||||
# )
|
||||
# Rt0.login()
|
||||
# d0 = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
# d1 = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
# tickets = Rt0.search(
|
||||
# Queue=ALL_QUEUES,
|
||||
# raw_query=f"Updated > '{d0}' AND Updated < '{d1}'",
|
||||
# )
|
||||
# doc_batch: List[Document] = []
|
||||
# for ticket in tickets:
|
||||
# ticket_keys_to_omit = ["id", "Subject"]
|
||||
# tid: int = int(ticket["numerical_id"])
|
||||
# ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}"
|
||||
# logger.info(f"Processing ticket {tid}")
|
||||
# doc = Document(
|
||||
# id=ticket["id"],
|
||||
# # Will add title to the first section later in processing
|
||||
# sections=[Section(link=ticketLink, text="")]
|
||||
# + self.build_doc_sections_from_txn(Rt0, tid),
|
||||
# source=DocumentSource.REQUESTTRACKER,
|
||||
# semantic_identifier=ticket["Subject"],
|
||||
# metadata={
|
||||
# key: value
|
||||
# for key, value in ticket.items()
|
||||
# if key not in ticket_keys_to_omit
|
||||
# },
|
||||
# )
|
||||
# doc_batch.append(doc)
|
||||
# if len(doc_batch) >= self.batch_size:
|
||||
# yield doc_batch
|
||||
# doc_batch = []
|
||||
# if doc_batch:
|
||||
# yield doc_batch
|
||||
# def poll_source(
|
||||
# self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
# ) -> GenerateDocumentsOutput:
|
||||
# # Keep query short, only look behind 1 day at maximum
|
||||
# one_day_ago: float = end - (24 * 60 * 60)
|
||||
# _start: float = start if start > one_day_ago else one_day_ago
|
||||
# start_datetime = datetime.fromtimestamp(_start, tz=timezone.utc)
|
||||
# end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
# yield from self._process_tickets(start_datetime, end_datetime)
|
||||
# if __name__ == "__main__":
|
||||
# import time
|
||||
# import os
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# logger.setLevel(LOG_LVL_DEBUG)
|
||||
# rt_connector = RequestTrackerConnector()
|
||||
# rt_connector.load_credentials(
|
||||
# {
|
||||
# "requesttracker_username": os.getenv("RT_USERNAME"),
|
||||
# "requesttracker_password": os.getenv("RT_PASSWORD"),
|
||||
# "requesttracker_base_url": os.getenv("RT_BASE_URL"),
|
||||
# }
|
||||
# )
|
||||
# current = time.time()
|
||||
# one_day_ago = current - (24 * 60 * 60) # 1 days
|
||||
# latest_docs = rt_connector.poll_source(one_day_ago, current)
|
||||
# for doc in latest_docs:
|
||||
# print(doc)
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import DEBUG as LOG_LVL_DEBUG
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from rt.rest1 import ALL_QUEUES
|
||||
from rt.rest1 import Rt
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RequestTrackerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestTrackerConnector(PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def txn_link(self, tid: int, txn: int) -> str:
|
||||
return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}"
|
||||
|
||||
def build_doc_sections_from_txn(
|
||||
self, connection: Rt, ticket_id: int
|
||||
) -> List[Section]:
|
||||
Sections: List[Section] = []
|
||||
|
||||
get_history_resp = connection.get_history(ticket_id)
|
||||
|
||||
if get_history_resp is None:
|
||||
raise RequestTrackerError(f"Ticket {ticket_id} cannot be found")
|
||||
|
||||
for tx in get_history_resp:
|
||||
Sections.append(
|
||||
Section(
|
||||
link=self.txn_link(ticket_id, int(tx["id"])),
|
||||
text="\n".join(
|
||||
[
|
||||
f"{k}:\n{v}\n" if k != "Attachments" else ""
|
||||
for (k, v) in tx.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
return Sections
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
self.rt_username = credentials.get("requesttracker_username")
|
||||
self.rt_password = credentials.get("requesttracker_password")
|
||||
self.rt_base_url = credentials.get("requesttracker_base_url")
|
||||
return None
|
||||
|
||||
# This does not include RT file attachments yet.
|
||||
def _process_tickets(
|
||||
self, start: datetime, end: datetime
|
||||
) -> GenerateDocumentsOutput:
|
||||
if any([self.rt_username, self.rt_password, self.rt_base_url]) is None:
|
||||
raise ConnectorMissingCredentialError("requesttracker")
|
||||
|
||||
Rt0 = Rt(
|
||||
f"{self.rt_base_url}/REST/1.0/",
|
||||
self.rt_username,
|
||||
self.rt_password,
|
||||
)
|
||||
|
||||
Rt0.login()
|
||||
|
||||
d0 = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
d1 = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
tickets = Rt0.search(
|
||||
Queue=ALL_QUEUES,
|
||||
raw_query=f"Updated > '{d0}' AND Updated < '{d1}'",
|
||||
)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket in tickets:
|
||||
ticket_keys_to_omit = ["id", "Subject"]
|
||||
tid: int = int(ticket["numerical_id"])
|
||||
ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}"
|
||||
logger.info(f"Processing ticket {tid}")
|
||||
doc = Document(
|
||||
id=ticket["id"],
|
||||
# Will add title to the first section later in processing
|
||||
sections=[Section(link=ticketLink, text="")]
|
||||
+ self.build_doc_sections_from_txn(Rt0, tid),
|
||||
source=DocumentSource.REQUESTTRACKER,
|
||||
semantic_identifier=ticket["Subject"],
|
||||
metadata={
|
||||
key: value
|
||||
for key, value in ticket.items()
|
||||
if key not in ticket_keys_to_omit
|
||||
},
|
||||
)
|
||||
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Keep query short, only look behind 1 day at maximum
|
||||
one_day_ago: float = end - (24 * 60 * 60)
|
||||
_start: float = start if start > one_day_ago else one_day_ago
|
||||
start_datetime = datetime.fromtimestamp(_start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
yield from self._process_tickets(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logger.setLevel(LOG_LVL_DEBUG)
|
||||
rt_connector = RequestTrackerConnector()
|
||||
rt_connector.load_credentials(
|
||||
{
|
||||
"requesttracker_username": os.getenv("RT_USERNAME"),
|
||||
"requesttracker_password": os.getenv("RT_PASSWORD"),
|
||||
"requesttracker_base_url": os.getenv("RT_BASE_URL"),
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - (24 * 60 * 60) # 1 days
|
||||
latest_docs = rt_connector.poll_source(one_day_ago, current)
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
|
||||
@@ -11,25 +11,17 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.connectors.salesforce.utils import extract_dict_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
# TODO: this connector does not work well at large scales
|
||||
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
|
||||
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
|
||||
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
@@ -37,7 +29,7 @@ ID_PREFIX = "SALESFORCE_"
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -251,22 +243,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
all_retrieved_ids: set[str] = set()
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
doc_metadata_list.extend(
|
||||
SlimDocument(
|
||||
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
perm_sync_data={},
|
||||
)
|
||||
all_retrieved_ids.update(
|
||||
f"{ID_PREFIX}{instance_dict.get('Id', '')}"
|
||||
for instance_dict in query_result["records"]
|
||||
)
|
||||
|
||||
yield doc_metadata_list
|
||||
return all_retrieved_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
||||
@@ -20,13 +20,10 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Fairly generous retry because it's not understood why occasionally GraphQL requests fail even with timeout > 1 min
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
|
||||
@@ -13,15 +13,13 @@ from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
@@ -328,7 +326,7 @@ def _get_all_doc_ids(
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
|
||||
@@ -340,14 +338,13 @@ def _get_all_doc_ids(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
all_doc_ids = set()
|
||||
for channel in filtered_channels:
|
||||
channel_id = channel["id"]
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
message_ts_set: set[str] = set()
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
@@ -356,21 +353,12 @@ def _get_all_doc_ids(
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
message_ts_set.add(message["ts"])
|
||||
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
|
||||
|
||||
channel_metadata_list: list[SlimDocument] = []
|
||||
for message_ts in message_ts_set:
|
||||
channel_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=f"{channel_id}__{message_ts}",
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
)
|
||||
|
||||
yield channel_metadata_list
|
||||
return all_doc_ids
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector, SlimConnector):
|
||||
class SlackPollConnector(PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
@@ -391,7 +379,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
@@ -441,7 +429,6 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -16,7 +16,6 @@ from danswer.connectors.slack.connector import filter_channels
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -373,7 +373,7 @@ class WebConnector(LoadConnector):
|
||||
page.close()
|
||||
except Exception as e:
|
||||
last_error = f"Failed to fetch '{current_url}': {e}"
|
||||
logger.exception(last_error)
|
||||
logger.error(last_error)
|
||||
playwright.stop()
|
||||
restart_playwright = True
|
||||
continue
|
||||
|
||||
@@ -211,7 +211,6 @@ def handle_regular_answer(
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
|
||||
@@ -7,6 +7,7 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.background.celery.celery_app import get_all_tenant_ids
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
@@ -46,7 +47,6 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import TenantSocketModeClient
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -57,7 +57,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
@@ -346,9 +345,7 @@ def process_message(
|
||||
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
|
||||
)
|
||||
logger.debug(f"Received Slack request of type: '{req.type}'")
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
if not prefilter_requests(req, client):
|
||||
@@ -360,59 +357,51 @@ def process_message(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
|
||||
# Be careful about this default, don't want to accidentally spam every channel
|
||||
# Users should be able to DM slack bot in their private channels though
|
||||
if (
|
||||
slack_bot_config is None
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
# Be careful about this default, don't want to accidentally spam every channel
|
||||
# Users should be able to DM slack bot in their private channels though
|
||||
if (
|
||||
slack_bot_config is None
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
|
||||
follow_up = bool(
|
||||
slack_bot_config
|
||||
and slack_bot_config.channel_config
|
||||
and slack_bot_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
follow_up = bool(
|
||||
slack_bot_config
|
||||
and slack_bot_config.channel_config
|
||||
and slack_bot_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if client.tenant_id:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@@ -510,9 +499,7 @@ if __name__ == "__main__":
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
if (
|
||||
tenant_id not in slack_bot_tokens
|
||||
@@ -546,11 +533,6 @@ if __name__ == "__main__":
|
||||
socket_client = _get_socket_client(
|
||||
latest_slack_bot_tokens, tenant_id
|
||||
)
|
||||
|
||||
# Initialize socket client for this tenant. Each tenant has its own
|
||||
# socket client, allowing for multiple concurrent connections (one
|
||||
# per tenant) with the tenant ID wrapped in the socket model client.
|
||||
# Each `connect` stores websocket connection in a separate thread.
|
||||
_initialize_socket_client(socket_client)
|
||||
|
||||
socket_clients[tenant_id] = socket_client
|
||||
|
||||
@@ -57,10 +57,7 @@ async def get_user_count() -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
) -> UP:
|
||||
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
|
||||
@@ -341,8 +341,6 @@ def add_credential_to_connector(
|
||||
access_type: AccessType,
|
||||
groups: list[int] | None,
|
||||
auto_sync_options: dict | None = None,
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
@@ -386,10 +384,9 @@ def add_credential_to_connector(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
status=initial_status,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
|
||||
@@ -40,8 +40,6 @@ CREDENTIAL_PERMISSIONS_TO_IGNORE = {
|
||||
DocumentSource.MEDIAWIKI,
|
||||
}
|
||||
|
||||
PUBLIC_CREDENTIAL_ID = 0
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select,
|
||||
@@ -243,7 +241,7 @@ def create_credential(
|
||||
curator_public=credential_data.curator_public,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush() # This ensures the credential gets an ID
|
||||
db_session.flush() # This ensures the credential gets an IDcredentials
|
||||
_relate_credential_to_user_groups__no_commit(
|
||||
db_session=db_session,
|
||||
credential_id=credential.id,
|
||||
@@ -386,11 +384,12 @@ def delete_credential(
|
||||
|
||||
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
public_cred_id = 0
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
"There must exist an empty public credential for data connectors that do not require additional Auth."
|
||||
)
|
||||
first_credential = fetch_credential_by_id(PUBLIC_CREDENTIAL_ID, None, db_session)
|
||||
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
|
||||
|
||||
if first_credential is not None:
|
||||
if first_credential.credential_json != {} or first_credential.user is not None:
|
||||
@@ -398,7 +397,7 @@ def create_initial_public_credential(db_session: Session) -> None:
|
||||
return
|
||||
|
||||
credential = Credential(
|
||||
id=PUBLIC_CREDENTIAL_ID,
|
||||
id=public_cred_id,
|
||||
credential_json={},
|
||||
user_id=None,
|
||||
)
|
||||
@@ -406,24 +405,6 @@ def create_initial_public_credential(db_session: Session) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_gmail_credentials(db_session: Session) -> None:
|
||||
gmail_credentials = fetch_credentials_by_source(
|
||||
db_session=db_session, user=None, document_source=DocumentSource.GMAIL
|
||||
)
|
||||
for credential in gmail_credentials:
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_google_drive_credentials(db_session: Session) -> None:
|
||||
google_drive_credentials = fetch_credentials_by_source(
|
||||
db_session=db_session, user=None, document_source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
for credential in google_drive_credentials:
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_gmail_service_account_credentials(
|
||||
user: User | None, db_session: Session
|
||||
) -> None:
|
||||
|
||||
@@ -375,20 +375,6 @@ def update_docs_last_modified__no_commit(
|
||||
doc.last_modified = now
|
||||
|
||||
|
||||
def mark_document_as_modified(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc = db_session.scalar(stmt)
|
||||
if doc is None:
|
||||
raise ValueError(f"No document with ID: {document_id}")
|
||||
|
||||
# update last_synced
|
||||
doc.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc = db_session.scalar(stmt)
|
||||
|
||||
@@ -25,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from danswer.configs.app_configs import POSTGRES_DB
|
||||
@@ -34,13 +35,11 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -189,29 +188,6 @@ class SqlEngine:
|
||||
return cls._app_name
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
|
||||
valid_tenants = [
|
||||
tenant
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
@@ -260,30 +236,27 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
token = request.cookies.get("tenant_details")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
current_value = current_tenant_id.get()
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
return current_value
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
||||
tenant_id = payload.get("tenant_id")
|
||||
if not tenant_id:
|
||||
return current_tenant_id.get()
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
current_tenant_id.set(tenant_id)
|
||||
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
return current_tenant_id.get()
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -294,7 +267,7 @@ async def get_async_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
if tenant_id is None:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
@@ -326,9 +299,9 @@ def get_session_with_tenant(
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
tenant_id = current_tenant_id.get()
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
current_tenant_id.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
@@ -364,22 +337,26 @@ def get_session_with_tenant(
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
logger.debug(
|
||||
f"Set search_path to {tenant_id} for connection {connection_record}"
|
||||
)
|
||||
|
||||
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
def get_session_generator_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id == "public" and MULTI_TENANT:
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
@@ -395,7 +372,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
tenant_id = current_tenant_id.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -67,32 +66,6 @@ def create_index_attempt(
|
||||
return new_attempt.id
|
||||
|
||||
|
||||
def mock_successful_index_attempt(
|
||||
connector_credential_pair_id: int,
|
||||
search_settings_id: int,
|
||||
docs_indexed: int,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
"""Should not be used in any user triggered flows"""
|
||||
db_time = func.now()
|
||||
new_attempt = IndexAttempt(
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
from_beginning=True,
|
||||
status=IndexingStatus.SUCCESS,
|
||||
total_docs_indexed=docs_indexed,
|
||||
new_docs_indexed=docs_indexed,
|
||||
# Need this to be some convincing random looking value and it can't be 0
|
||||
# or the indexing rate would calculate out to infinity
|
||||
time_started=db_time - timedelta(seconds=1.92),
|
||||
time_updated=db_time,
|
||||
)
|
||||
db_session.add(new_attempt)
|
||||
db_session.commit()
|
||||
|
||||
return new_attempt.id
|
||||
|
||||
|
||||
def get_in_progress_index_attempts(
|
||||
connector_id: int | None,
|
||||
db_session: Session,
|
||||
|
||||
@@ -95,11 +95,10 @@ def upsert_llm_provider(
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return full_llm_provider
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
|
||||
@@ -56,7 +56,7 @@ def get_notification_by_id(
|
||||
if not notif:
|
||||
raise ValueError(f"No notification found with id {notification_id}")
|
||||
if notif.user_id != user_id and not (
|
||||
notif.user_id is None and user is not None and user.role == UserRole.ADMIN
|
||||
notif.user_id is None and user.role == UserRole.ADMIN
|
||||
):
|
||||
raise PermissionError(
|
||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||
|
||||
@@ -328,6 +328,7 @@ def update_all_personas_display_priority(
|
||||
|
||||
for persona in personas:
|
||||
persona.display_priority = display_priority_map[persona.id]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
@@ -14,7 +15,6 @@ from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_default_document_index(
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
@@ -193,21 +194,20 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
document_chunks: list[dict] = []
|
||||
while True:
|
||||
response = requests.get(url, params=params)
|
||||
try:
|
||||
filtered_params = {k: v for k, v in params.items() if v is not None}
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=filtered_params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = "Failed to query Vespa"
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
|
||||
response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
|
||||
error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}"
|
||||
logger.error(
|
||||
f"{error_base}:\n"
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
f"{request_info}\n"
|
||||
f"{response_info}\n"
|
||||
f"Exception: {e}"
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
raise requests.HTTPError(error_base) from e
|
||||
|
||||
# Check if the response contains any documents
|
||||
response_data = response.json()
|
||||
@@ -301,13 +301,17 @@ def query_vespa(
|
||||
response = http_client.post(SEARCH_ENDPOINT, json=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
|
||||
response_info = (
|
||||
f"Status Code: {response.status_code}\n"
|
||||
f"Response Content: {response.text}"
|
||||
)
|
||||
error_base = "Failed to query Vespa"
|
||||
logger.error(
|
||||
f"{error_base}:\n"
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
f"{request_info}\n"
|
||||
f"{response_info}\n"
|
||||
f"Exception: {e}"
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -72,7 +73,6 @@ from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ def _does_document_exist(
|
||||
chunk. This checks for whether the chunk exists already in the index"""
|
||||
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
doc_fetch_response = http_client.get(doc_url)
|
||||
|
||||
if doc_fetch_response.status_code == 404:
|
||||
return False
|
||||
|
||||
@@ -117,7 +118,7 @@ def get_existing_documents_from_chunks(
|
||||
return document_ids
|
||||
|
||||
|
||||
@retry(tries=5, delay=1, backoff=2)
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _index_vespa_chunk(
|
||||
chunk: DocMetadataAwareIndexChunk,
|
||||
index_name: str,
|
||||
|
||||
@@ -29,7 +29,6 @@ VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
|
||||
# main search application
|
||||
VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"
|
||||
|
||||
|
||||
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
|
||||
DOCUMENT_ID_ENDPOINT = (
|
||||
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user