Compare commits

..

64 Commits

Author SHA1 Message Date
pablodanswer
c32d428eb7 k 2024-10-26 16:20:18 -07:00
Yuhong Sun
aa0f307cc7 Backport Test Final (#2942) 2024-10-26 21:52:59 +00:00
Yuhong Sun
e6bef573ba Backport Correct Branch (#2941) 2024-10-26 14:34:24 -07:00
Yuhong Sun
f6f9112b76 Backport Test (#2940) 2024-10-26 14:23:43 -07:00
Yuhong Sun
accdd580d7 Backport Test (#2939) 2024-10-26 13:59:55 -07:00
Yuhong Sun
4bcd65ed92 Harmless Backport Test (#2938) 2024-10-26 13:47:09 -07:00
Yuhong Sun
80f8d7a486 Backport Permissions (#2937) 2024-10-26 13:42:09 -07:00
pablodanswer
e8c28e79c9 ensure proper sentry silencing (#2934)
* ensure proper sentry silencing

* add comments
2024-10-26 20:18:41 +00:00
Yuhong Sun
b4bc6d994d Backport Auth (#2936) 2024-10-26 13:20:02 -07:00
Yuhong Sun
ccc68c5c34 Backport Test (#2935) 2024-10-26 13:09:07 -07:00
pablodanswer
848d86b886 feat: sentry updates (#2929) 2024-10-26 19:06:46 +00:00
Yuhong Sun
c0ab86bac2 Backport Branch Fix (#2931) 2024-10-26 12:04:52 -07:00
Yuhong Sun
8c2138a6ef Backport Test (#2930) 2024-10-26 11:51:03 -07:00
pablodanswer
9def9f0dba add posthog + layout rework (#2926)
* add posthog + layout rework

* remove posthog node

* nit
2024-10-26 18:15:01 +00:00
pablodanswer
5e01d6befb check for index swap (#2922)
* check for index swap

* k

* minor

* k

* nit
2024-10-26 00:26:02 +00:00
hagen-danswer
94edcac36e upgraded claude model strings (#2876)
* upgraded model strings

* trolled

* we do a little trolling

* reeeeeee

* alembic upgrade

* added ignore

* bump litellm

* k

* nit

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-10-26 00:11:52 +00:00
pablodanswer
9b147ae437 Tenant integration tests (#2913)
* check for index swap

* initial bones

* kk

* k

* k:

* nit

* nit

* rebase + update

* nit

* minior update

* k

* minor integration test fixes

* nit

* ensure we build test docker image

* remove one space

* k

* ensure we wipe volumes

* remove log

* typo

* nit

* k

* k
2024-10-25 18:47:17 +00:00
Chris Weaver
bd63119684 Fix structured outputs (#2923)
* Fix structured outputs

* Add back rest
2024-10-25 18:19:54 +00:00
pablodanswer
76415aff41 Ensure proper modal fallback (#2906)
* modal fallback

* nit

* k

* k
2024-10-25 17:59:43 +00:00
Weves
4ca38201d1 Fix IT fixture ordering 2024-10-24 22:43:38 -07:00
Chris Weaver
4a47e9a841 Add strict json mode (#2917) 2024-10-24 22:38:46 -07:00
Yuhong Sun
d7a30b01d2 Harmless Backport (#2916) 2024-10-24 20:56:59 -07:00
Yuhong Sun
9c0f927e16 Workflow (#2915) 2024-10-24 20:53:48 -07:00
Yuhong Sun
55b9111410 Harmless Backport (#2914) 2024-10-24 20:43:16 -07:00
Yuhong Sun
07a4e112a4 Dev Experience (#2912) 2024-10-24 20:25:36 -07:00
rkuo-danswer
eaa8ae7399 Bugfix/connector deletion lockout (#2901)
* first cut at deletion hardening

* clean up logging

* remove commented code
2024-10-25 02:43:57 +00:00
rkuo-danswer
b9781c43fb Merge pull request #2909 from danswer-ai/bugfix/loopio
loopio connector: entry["id"] can apparently be a number, so convert to str
2024-10-24 19:55:47 -07:00
Yuhong Sun
a931494866 Harmless Backport (#2911) 2024-10-24 19:17:11 -07:00
Yuhong Sun
863f00f015 Auto Backport Partial (#2910) 2024-10-24 19:13:09 -07:00
pablodanswer
eae1dad0fa Silence unnecessary debug log (#2908)
* silence log

* silence
2024-10-25 01:32:53 +00:00
Richard Kuo (Danswer)
10b5b55658 entry["id"] can apparently be a number, so convert to str 2024-10-24 18:31:10 -07:00
Yuhong Sun
b49a9ab171 Seeding (#2902)
* checkpoint

* k

* k

* k

* fixed slack api calls

* missed one

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-10-24 23:45:48 +00:00
rkuo-danswer
9f50417109 try hiding celery task spam (#2905)
* try hiding celery task spam

* mypy fix
2024-10-24 22:44:20 +00:00
rkuo-danswer
94b4dc1656 can't add to primary_worker_locks if it doesn't exist (#2903)
* can't add to primary_worker_locks if it doesn't exist

* move init
2024-10-24 21:49:18 +00:00
rkuo-danswer
4bce143d6e Merge pull request #2904 from danswer-ai/bugfix/fix-typo
fix typo
2024-10-24 15:00:04 -07:00
pablodanswer
33eabf1b25 Add global assistants context (#2900)
* add global assistants context

* nit

* minor cleanup

* minor clarity

* nit
2024-10-24 21:27:55 +00:00
pablodanswer
da979e5745 More intuitive search settings interfaces (#2899)
* clearer search settings interfaces

* nits
2024-10-24 14:27:34 -07:00
Richard Kuo (Danswer)
705b825580 fix typo 2024-10-24 14:21:38 -07:00
Richard Kuo (Danswer)
32b595dfe1 update stale workflow 2024-10-24 13:31:39 -07:00
rkuo-danswer
2b9a751b96 working chat feedback dump script (with api addition) (#2891)
* working chat feedback dump script (with api addition)

* mypy fix

* comment out pydantic models (but leave for reference)

* small code review tweaks

* bump to clear vercel issue?
2024-10-24 19:50:09 +00:00
pablodanswer
1b6b134722 Clearer azure models (#2898)
* clear up llm

* remove logs
2024-10-24 17:29:36 +00:00
pablodanswer
0545fb4443 Multitenant redis update (#2889)
* add multi tenancy to redis

* rename context var

* k

* args -> kwargs

* minor update to kv interface

* robustify
2024-10-24 02:12:25 +00:00
hagen-danswer
b9fb657d81 Temporary fix for empty Google App credentials (#2892)
* Temporary fix for empty Google App credentials

* added it to credential creation
2024-10-24 00:49:04 +00:00
pablodanswer
14e75bbd24 add default schema config (#2888)
* add default schema config

* resolve circular import

* k
2024-10-23 23:12:17 +00:00
rkuo-danswer
3eb67baf5b Bugfix/indexing UI (#2879)
* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a62422068.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* disentangle configuration for different workers and beats.

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* missed a file

* scope db sessions to short lengths

* update launch.json template

* fix types

* keep index button disabled until indexing is truly finished

* change priority order of tooltips

* should be using the logger from app_base

* if we run out of retries, just mark the doc as modified so it gets synced later

* tighten up the logging ... we know these are ID's

* add logging
2024-10-23 20:25:52 +00:00
pablodanswer
8b72264535 Gating Notifications (#2868)
* functional notifications

* typing

* minor

* ports

* nit

* verify functionality

* pretty
2024-10-23 20:20:20 +00:00
pablodanswer
786a46cbd0 sticky credential description (#2886) 2024-10-23 19:59:14 +00:00
hagen-danswer
7abbfa37bb Tiny confluence fix (#2885)
* Tiny confluence fix

* Update utils.py

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-10-23 19:57:00 +00:00
pablodanswer
143da5bc0d add copying for unrecognized languages (#2883)
* add copying for unrecognized languages

* k
2024-10-23 17:26:54 +00:00
pablodanswer
5703ea47d2 Auth on main (#2878)
* add cloud auth type

* k

* robustified cloud auth type

* k

* minor typing
2024-10-23 16:46:30 +00:00
rkuo-danswer
9105f95d13 Feature/celery refactor (#2813)
* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a62422068.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* disentangle configuration for different workers and beats.

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* missed a file

* scope db sessions to short lengths

* update launch.json template

* fix types

* code review
2024-10-22 22:57:36 +00:00
Yuhong Sun
eccec6ab7c Notion Fix Nested Properties (#2877) 2024-10-22 14:10:31 -07:00
hagen-danswer
914da2e4cb Confluence polish (#2874) 2024-10-22 20:41:47 +00:00
Yuhong Sun
e031576c87 Salesforce Connector Note (#2872) 2024-10-22 10:05:28 -07:00
Richard Kuo (Danswer)
bae794706c add stale issues and pr's cron 2024-10-22 09:46:14 -07:00
Chris Weaver
6e9b6a1075 Handle models like openai/bedrock/claude-3.5-... (#2869)
* Handle models like openai/bedrock/claude-3.5-...

* Fix log statement
2024-10-22 05:27:26 +00:00
rkuo-danswer
e4779c29a7 tighter signaling to prevent indexing cleanup from hitting tasks that are just starting (#2867)
* better indexing synchronization

* add logging for fence wait

* handle the task not creating

* add more logging

* add more logging

* raise retry count
2024-10-21 23:46:23 +00:00
hagen-danswer
802086ee57 Refactored Confluence Connector (#2859)
* Refactored Confluence Connector

* rename metadataconnector to slimconnector

Finish rename

* danswer->onyx

* added rec

* typo

* refactored doc_sync for confluence

* mypy + enable tests

* tested and fixed for confluence cloud

* fixed all server syncing

* fixed connector test

* mypy+connector test fixes

* addressed richards comments

* minor fix
2024-10-21 23:03:40 +00:00
Chris Weaver
c516f3541c Make it so you can update model providers (#2866) 2024-10-21 18:51:53 +00:00
pablodanswer
45d852a9db modal onboarding clarity (#2780) 2024-10-21 03:42:26 +00:00
pablodanswer
cee68106ef Minor vespa standardization (#2861)
* minor additional standardization

* nit: typo

* k

* account for malformed params
2024-10-21 00:41:18 +00:00
pablodanswer
a24b465663 Minor tenant ID improvements (#2850)
* add migration dockerfile

* address edge case

* k

* k

* k

* nit

* k

* k

* k

* k

* remove

* k

* add comment
2024-10-20 23:48:00 +00:00
pablodanswer
7ab0063dc6 (minor) quote overflow (#2862)
* k

* k
2024-10-20 23:31:18 +00:00
Yuhong Sun
dd2551040f Docstring Update for Docs (#2863) 2024-10-20 15:31:08 -07:00
227 changed files with 16697 additions and 3825 deletions

View File

@@ -6,20 +6,24 @@
[Describe the tests you ran to verify your changes]
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Accepted Risk (provide if relevant)
N/A
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Related Issue(s) (provide if relevant)
N/A
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

View File

@@ -0,0 +1,23 @@
name: 'Nightly - Close stale issues and PRs'
on:
schedule:
- cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC
permissions:
# contents: write # only for delete-branch option
issues: write
pull-requests: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
days-before-stale: 75
# days-before-close: 90 # uncomment after we test stale behavior

View File

@@ -72,7 +72,7 @@ jobs:
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
@@ -85,7 +85,58 @@ jobs:
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
danswer/danswer-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_multitenant_tests
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -130,7 +181,7 @@ jobs:
done
echo "Finished waiting for service."
- name: Run integration tests
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
@@ -145,7 +196,8 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
danswer/danswer-integration:test \
/app/tests/integration/tests
continue-on-error: true
id: run_tests
@@ -158,6 +210,11 @@ jobs:
echo "All integration tests passed successfully."
fi
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Save Docker logs
if: success() || failure()
run: |

View File

@@ -0,0 +1,98 @@
name: Backport on Merge
on:
pull_request:
types: [closed] # Later we check for merge so only PRs that go in can get backported
permissions:
contents: write
actions: write # Allows the workflow to trigger other workflows on push
jobs:
backport:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
with:
fetch-depth: 0 # Fetch all history for all branches and tags
- name: Set up Git
run: |
git config --global user.name "yuhongsun96"
git config --global user.email "yuhongsun96@gmail.com"
# Configure Git to use the PAT for authentication
git remote set-url origin https://yuhongsun96:${{ secrets.YUHONG_GH_ACTIONS }}@github.com/${{ github.repository }}.git
- name: Check for Backport Checkbox
id: checkbox-check
run: |
PR_BODY="${{ github.event.pull_request.body }}"
if [[ "$PR_BODY" == *"[x] This PR should be backported"* ]]; then
echo "backport=true" >> $GITHUB_OUTPUT
else
echo "backport=false" >> $GITHUB_OUTPUT
fi
- name: List and sort release branches
id: list-branches
run: |
git fetch --all --tags
BRANCHES=$(git for-each-ref --format='%(refname:short)' refs/remotes/origin/release/* | sed 's|origin/release/||' | sort -Vr)
BETA=$(echo "$BRANCHES" | head -n 1)
STABLE=$(echo "$BRANCHES" | head -n 2 | tail -n 1)
echo "beta=release/$BETA" >> $GITHUB_OUTPUT
echo "stable=release/$STABLE" >> $GITHUB_OUTPUT
# Fetch latest tags for beta and stable
LATEST_BETA_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*-beta.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$" | grep -v -- "-cloud" | sort -Vr | head -n 1)
LATEST_STABLE_TAG=$(git tag -l "v[0-9]*.[0-9]*.[0-9]*" | grep -E "^v[0-9]+\.[0-9]+\.[0-9]+$" | sort -Vr | head -n 1)
# Increment latest beta tag
NEW_BETA_TAG=$(echo $LATEST_BETA_TAG | awk -F '[.-]' '{print $1 "." $2 "." $3 "-beta." ($NF+1)}')
# Increment latest stable tag
NEW_STABLE_TAG=$(echo $LATEST_STABLE_TAG | awk -F '.' '{print $1 "." $2 "." ($3+1)}')
echo "latest_beta_tag=$LATEST_BETA_TAG" >> $GITHUB_OUTPUT
echo "latest_stable_tag=$LATEST_STABLE_TAG" >> $GITHUB_OUTPUT
echo "new_beta_tag=$NEW_BETA_TAG" >> $GITHUB_OUTPUT
echo "new_stable_tag=$NEW_STABLE_TAG" >> $GITHUB_OUTPUT
- name: Echo branch and tag information
run: |
echo "Beta branch: ${{ steps.list-branches.outputs.beta }}"
echo "Stable branch: ${{ steps.list-branches.outputs.stable }}"
echo "Latest beta tag: ${{ steps.list-branches.outputs.latest_beta_tag }}"
echo "Latest stable tag: ${{ steps.list-branches.outputs.latest_stable_tag }}"
echo "New beta tag: ${{ steps.list-branches.outputs.new_beta_tag }}"
echo "New stable tag: ${{ steps.list-branches.outputs.new_stable_tag }}"
- name: Trigger Backport
if: steps.checkbox-check.outputs.backport == 'true'
run: |
set -e
echo "Backporting to beta ${{ steps.list-branches.outputs.beta }} and stable ${{ steps.list-branches.outputs.stable }}"
# Fetch all history for all branches and tags
git fetch --prune
# Checkout the beta branch
git checkout ${{ steps.list-branches.outputs.beta }}
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to beta failed due to conflicts."
exit 1
}
# Create new beta tag
git tag ${{ steps.list-branches.outputs.new_beta_tag }}
# Push the changes and tag to the beta branch
git push origin ${{ steps.list-branches.outputs.beta }} --force
git push origin ${{ steps.list-branches.outputs.new_beta_tag }}
# Checkout the stable branch
git checkout ${{ steps.list-branches.outputs.stable }}
# Cherry-pick the merge commit from the merged PR
git cherry-pick -m 1 ${{ github.event.pull_request.merge_commit_sha }} || {
echo "Cherry-pick to stable failed due to conflicts."
exit 1
}
# Create new stable tag
git tag ${{ steps.list-branches.outputs.new_stable_tag }}
# Push the changes and tag to the stable branch
git push origin ${{ steps.list-branches.outputs.stable }} --force
git push origin ${{ steps.list-branches.outputs.new_stable_tag }}

View File

@@ -6,19 +6,69 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Danswer Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
"Slack Bot"
]
}
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat"
],
"presentation": {
"group": "1",
}
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
@@ -29,7 +79,11 @@
"runtimeArgs": [
"run", "dev"
],
"console": "integratedTerminal"
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
@@ -48,7 +102,11 @@
"--reload",
"--port",
"9000"
]
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
@@ -68,43 +126,13 @@
"--reload",
"--port",
"8080"
]
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
},
{
"name": "Indexing",
"consoleName": "Indexing",
"type": "debugpy",
"request": "launch",
"program": "danswer/background/update.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"--no-indexing"
]
},
// For the listner to access the Slack API,
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
@@ -118,7 +146,151 @@
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Pytest",
@@ -137,8 +309,22 @@
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
@@ -147,7 +333,27 @@
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true
}
"stopOnEntry": true,
"presentation": {
"group": "3",
},
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
},
]
}

View File

@@ -68,7 +68,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 🚧 Roadmap
* Chat/Prompt sharing with specific teammates and user groups.
* Multi-Model model support, chat with images, video etc.
* Multimodal model support, chat with images, video etc.
* Choosing between LLMs and parameters during chat session.
* Tool calling and agent configurations options.
* Organizational understanding and ability to locate and suggest experts from your team.

View File

@@ -13,7 +13,8 @@ from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Alembic Config object
config = context.config
@@ -57,11 +58,15 @@ def get_schema_options() -> tuple[str, bool, bool]:
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if MULTI_TENANT and schema_name == "public":
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."

View File

@@ -1,7 +1,9 @@
"""
"""Migrate chat_session and chat_message tables to use UUID primary keys
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa
@@ -12,8 +14,6 @@ branch_labels = None
depends_on = None
"""
Migrate chat_session and chat_message tables to use UUID primary keys.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs

View File

@@ -31,6 +31,12 @@ def upgrade() -> None:
def downgrade() -> None:
# First, update any null values to a default value
op.execute(
"UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
)
# Then, make the column non-nullable
op.alter_column(
"connector_credential_pair",
"last_attempt_status",

View File

@@ -1,3 +1,3 @@
import os
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
__version__ = os.environ.get("DANSWER_VERSION", "") or "Development"

View File

@@ -70,3 +70,12 @@ class DocumentAccess(ExternalAccess):
user_groups=set(user_groups),
is_public=is_public,
)
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)

View File

@@ -9,6 +9,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
@@ -93,7 +94,9 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -132,7 +135,9 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
def verify_email_is_invited(email: str) -> None:
@@ -187,7 +192,7 @@ def verify_email_domain(email: str) -> None:
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
@@ -235,7 +240,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
@@ -246,7 +253,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
@@ -285,7 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def on_after_login(
@@ -327,7 +334,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
@@ -337,10 +346,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
@@ -426,7 +436,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user.oidc_expiry = None # type: ignore
if token:
current_tenant_id.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user

View File

@@ -0,0 +1,310 @@
import logging
import multiprocessing
import time
from typing import Any
import sentry_sdk
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
if state not in READY_STATES:
return
if not task_id:
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
while True:
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("All tenant primary workers are ready. Continuing...")
return
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_locks"):
return
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
# hide celery task received spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)

View File

@@ -0,0 +1,99 @@
from datetime import timedelta
from typing import Any
from celery import Celery
from celery import signals
from celery.signals import beat_init
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.beat")
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -0,0 +1,88 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.heavy")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
]
)

View File

@@ -0,0 +1,88 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.indexing")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.indexing",
]
)

View File

@@ -0,0 +1,89 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.light")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
)

View File

@@ -0,0 +1,300 @@
import multiprocessing
from typing import Any
from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.primary")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks = {}
# This is singleton work that should be done on startup exactly once
# by the primary worker
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
def start(self, worker: Any) -> None:
if not celery_is_worker_primary(worker):
return
# Access the worker's event loop (hub)
hub = worker.consumer.controller.hub
# Schedule the periodic task
self.task_tref = hub.call_repeatedly(
self.interval, self.run_periodic_task, worker
)
task_logger.info("Scheduled periodic task with hub.")
def run_periodic_task(self, worker: Any) -> None:
try:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_locks"):
return
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
r = get_redis_client(tenant_id=tenant_id)
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception:
task_logger.exception("Periodic task failed.")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
if self.task_tref:
self.task_tref.cancel()
task_logger.info("Canceled periodic task with hub.")
celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
)

View File

@@ -0,0 +1,26 @@
import logging
from celery import current_task
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
class CeleryTaskColoredFormatter(ColoredFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)

View File

@@ -1,619 +0,0 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Celery
from celery import current_task
from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
if state not in READY_STATES:
return
if not task_id:
return
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)'
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
tenant_ids = get_all_tenant_ids()
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
while True:
# Log all the locks in Redis
all_locks = r.keys("*")
logger.notice(f"Current Redis locks: {all_locks}")
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
for tenant_id, lock in sender.primary_worker_locks.items():
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
if lock.owned():
lock.release()
sender.primary_worker_locks = {}
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
class CeleryTaskColoredFormatter(ColoredFormatter):
def format(self, record: logging.LogRecord) -> str:
task = current_task
if task and task.request:
record.__dict__.update(task_id=task.request.id, task_name=task.name)
record.msg = f"[{task.name}({task.request.id})] {record.msg}"
return super().format(record)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
root_logger.addHandler(root_file_handler)
root_logger.setLevel(loglevel)
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_formatter = CeleryTaskPlainFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
task_logger.addHandler(task_file_handler)
task_logger.setLevel(loglevel)
task_logger.propagate = False
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# Requires the Hub component
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
def start(self, worker: Any) -> None:
if not celery_is_worker_primary(worker):
return
# Access the worker's event loop (hub)
hub = worker.consumer.controller.hub
# Schedule the periodic task
self.task_tref = hub.call_repeatedly(
self.interval, self.run_periodic_task, worker
)
task_logger.info("Scheduled periodic task with hub.")
def run_periodic_task(self, worker: Any) -> None:
try:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_locks"):
return
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
r = get_redis_client(tenant_id=tenant_id)
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
if self.task_tref:
self.task_tref.cancel()
task_logger.info("Canceled periodic task with hub.")
celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
)
#####
# Celery Beat (Periodic Tasks) Settings
#####
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -10,7 +10,7 @@ from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.background.celery.configs.base import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
@@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper):
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
last_lock_time = time.monotonic()
async_results = []
@@ -465,14 +467,8 @@ class RedisConnectorPruning(RedisObjectHelper):
return len(async_results)
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
def is_pruning(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=int(self._id), db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
if redis_client.exists(self.fence_key):
return True
@@ -538,6 +534,36 @@ class RedisConnectorIndexing(RedisObjectHelper):
) -> int | None:
return None
def is_indexing(self, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
if redis_client.exists(self.fence_key):
return True
return False
class RedisConnectorStop(RedisObjectHelper):
"""Used to signal any running tasks for a connector to stop. We should refactor
connector related redis helpers into a single class.
"""
PREFIX = "connectorstop"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.

View File

@@ -1,25 +1,21 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
@@ -75,13 +71,15 @@ def get_deletion_attempt_snapshot(
)
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
@@ -91,10 +89,13 @@ def extract_ids_from_runnable_connector(
"""
all_connector_doc_ids: set[str] = set()
if isinstance(runnable_connector, SlimConnector):
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_generator = None
if isinstance(runnable_connector, IdConnector):
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
elif isinstance(runnable_connector, LoadConnector):
if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
@@ -103,16 +104,17 @@ def extract_ids_from_runnable_connector(
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
if doc_batch_generator:
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
@@ -133,33 +135,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken to determine if a celery worker
is 'primary', as defined by us. But the way we do it is to check the hostname set
for the celery worker, which can be done either in celeryconfig.py or on the
for the celery worker, which can be done on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("primary"):
return True
return False
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants

View File

@@ -31,21 +31,10 @@ if REDIS_SSL:
if REDIS_SSL_CA_CERTS:
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
# region Broker settings
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
@@ -60,6 +49,7 @@ broker_transport_options = {
"socket_keepalive": True,
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
}
# endregion
# redis backend settings
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
@@ -73,10 +63,19 @@ redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# region Task result backend settings
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# endregion
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
# region Notes on serialization performance
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
@@ -102,3 +101,4 @@ result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]
# endregion

View File

@@ -0,0 +1,14 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
import danswer.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default

View File

@@ -0,0 +1,20 @@
import danswer.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -0,0 +1,21 @@
import danswer.background.celery.configs.base as shared_config
from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -0,0 +1,22 @@
import danswer.background.celery.configs.base as shared_config
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER

View File

@@ -0,0 +1,20 @@
import danswer.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,29 +1,45 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_pool import get_redis_client
class TaskDependencyError(RuntimeError):
"""Raised to the caller to indicate dependent tasks are running that would interfere
with connector deletion."""
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
@@ -36,12 +52,30 @@ def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
if not lock_beat.acquire(blocking=False):
return
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat, tenant_id
)
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
rcs = RedisConnectorStop(cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
r.set(rcs.fence_key, cc_pair_id)
else:
# clear the stop signal if it exists ... no longer needed
r.delete(rcs.fence_key)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -54,7 +88,8 @@ def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -63,51 +98,87 @@ def try_generate_document_cc_pair_cleanup_tasks(
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
still running. In our case, the caller reacts by setting a stop signal in Redis to
exit those tasks as quickly as possible.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
rcd = RedisConnectorDeletion(cc_pair_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# we need to load the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# set a basic fence to start
fence_value = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
r.set(rcd.fence_key, fence_value.model_dump_json())
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
try:
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
if r.get(rci.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
rcp = RedisConnectorPruning(cc_pair_id)
if r.get(rcp.fence_key):
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
except TaskDependencyError:
r.delete(rcd.fence_key)
raise
except Exception:
task_logger.exception("Unexpected exception")
r.delete(rcd.fence_key)
return None
else:
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
fence_value.num_tasks = tasks_generated
r.set(rcd.fence_key, fence_value.model_dump_json())
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated

View File

@@ -5,18 +5,26 @@ from time import sleep
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -40,18 +48,48 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
logger = setup_logger()
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
self.redis_client.incrby(self.generator_progress_key, amount)
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
bind=True,
)
def check_for_indexing(*, tenant_id: str | None) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client(tenant_id=tenant_id)
@@ -64,31 +102,54 @@ def check_for_indexing(*, tenant_id: str | None) -> int | None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if current_search_settings.provider_type is None and not MULTI_TENANT:
embedding_model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair.id, search_settings_instance.id
cc_pair_id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session
)
if not cc_pair:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
@@ -104,6 +165,7 @@ def check_for_indexing(*, tenant_id: str | None) -> int | None:
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
@@ -213,6 +275,7 @@ def _should_index(
def try_creating_indexing_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
@@ -248,6 +311,10 @@ def try_creating_indexing_task(
return None
# skip indexing if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -258,7 +325,19 @@ def try_creating_indexing_task(
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# create the index attempt ... just for tracking purposes
# set a basic fence to start
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
r.set(rci.fence_key, fence_value.model_dump_json())
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
@@ -279,17 +358,14 @@ def try_creating_indexing_task(
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
return None
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
# now fill out the fence with the rest of the data
fence_value.index_attempt_id = index_attempt_id
fence_value.celery_task_id = result.id
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
r.delete(rci.fence_key)
task_logger.exception("Unexpected exception")
return None
finally:
@@ -372,8 +448,56 @@ def connector_indexing_task(
r = get_redis_client(tenant_id=tenant_id)
rcd = RedisConnectorDeletion(cc_pair_id)
if r.exists(rcd.fence_key):
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"cc_pair={cc_pair_id} "
f"fence={rcd.fence_key}"
)
rcs = RedisConnectorStop(cc_pair_id)
if r.exists(rcs.fence_key):
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"cc_pair={cc_pair_id} "
f"fence={rcs.fence_key}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
while True:
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
task_logger.info(
f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
)
raise RuntimeError(f"Fence not found: fence={rci.fence_key}")
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
f"connector_indexing_task: fence_data not decodeable: fence={rci.fence_key}"
)
raise
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
task_logger.info(
f"connector_indexing_task - Waiting for fence: fence={rci.fence_key}"
)
sleep(1)
continue
task_logger.info(
f"connector_indexing_task - Fence found, continuing...: fence={rci.fence_key}"
)
break
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
@@ -383,17 +507,20 @@ def connector_indexing_task(
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
fence_data.started = datetime.now(timezone.utc)
r.set(rci.fence_key, fence_data.model_dump_json())
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
f"Index attempt not found: index_attempt={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
@@ -402,31 +529,31 @@ def connector_indexing_task(
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
# define a callback class
callback = RunIndexingCallback(
rcs.fence_key, rci.generator_progress_key, lock, r
)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
callback=callback,
)
# get back the total number of indexed docs and return it
@@ -439,9 +566,10 @@ def connector_indexing_task(
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)

View File

@@ -11,7 +11,7 @@ from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant

View File

@@ -3,15 +3,19 @@ from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisConnectorStop
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -23,6 +27,7 @@ from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
@@ -37,8 +42,9 @@ logger = setup_logger()
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(*, tenant_id: str | None) -> None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
@@ -51,15 +57,24 @@ def check_for_pruning(*, tenant_id: str | None) -> None:
if not lock_beat.acquire(blocking=False):
return
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
continue
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, tenant_id
self.app, cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
@@ -118,6 +133,7 @@ def is_pruning_due(
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
@@ -155,6 +171,10 @@ def try_creating_prune_generator_task(
return None
# skip pruning if the cc_pair is deleting
rcd = RedisConnectorDeletion(cc_pair.id)
if r.exists(rcd.fence_key):
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
@@ -196,9 +216,14 @@ def try_creating_prune_generator_task(
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_pruning_generator_task(
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
self: Task,
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str | None,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -216,7 +241,7 @@ def connector_pruning_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
@@ -234,22 +259,22 @@ def connector_pruning_generator_task(
)
return
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
InputType.PRUNE,
InputType.SLIM_RETRIEVAL,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
rcs = RedisConnectorStop(cc_pair_id)
callback = RunIndexingCallback(
rcs.fence_key, rcp.generator_progress_key, lock, r
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, redis_increment_callback
runnable_connector, callback
)
# a list of docs in our local index
@@ -267,7 +292,7 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
@@ -275,22 +300,24 @@ def connector_pruning_generator_task(
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
)
tasks_generated = rcp.generate_tasks(
celery_app, db_session, r, None, tenant_id
self.app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)

View File

@@ -0,0 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime

View File

@@ -0,0 +1,10 @@
from datetime import datetime
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None

View File

@@ -1,16 +1,14 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_modified
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_session_with_tenant
@@ -19,12 +17,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@shared_task(
@@ -32,7 +25,7 @@ class RedisConnectorIndexingFenceData(BaseModel):
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
)
def document_by_cc_pair_cleanup_task(
self: Task,
@@ -56,7 +49,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
task_logger.info(f"tenant_id={tenant_id} document_id={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -122,6 +115,8 @@ def document_by_cc_pair_cleanup_task(
else:
pass
db_session.commit()
task_logger.info(
f"tenant_id={tenant_id} "
f"document_id={document_id} "
@@ -129,16 +124,27 @@ def document_by_cc_pair_cleanup_task(
f"refcount={count} "
f"chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
return False
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
f"tenant_id={tenant_id} document_id={document_id}"
)
with get_session_with_tenant(tenant_id):
mark_document_as_modified(document_id, db_session)
return False
return True

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
@@ -14,8 +15,7 @@ from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
@@ -23,7 +23,12 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
RedisConnectorDeletionFenceData,
)
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
RedisConnectorIndexingFenceData,
)
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
@@ -54,7 +59,6 @@ from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -76,8 +80,9 @@ logger = setup_logger()
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -94,35 +99,53 @@ def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
return
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
try_generate_stale_document_sync_tasks(
self.app, db_session, r, lock_beat, tenant_id
)
# region document set scan
document_set_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
document_set_ids.append(document_set.id)
for document_set_id in document_set_ids:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat, tenant_id
self.app, document_set_id, db_session, r, lock_beat, tenant_id
)
# endregion
# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
# check if any user groups are not synced
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass
else:
usergroup_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
usergroup_ids.append(usergroup.id)
for usergroup_id in usergroup_ids:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat, tenant_id
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass
except SoftTimeLimitExceeded:
task_logger.info(
@@ -136,7 +159,11 @@ def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -187,7 +214,8 @@ def try_generate_stale_document_sync_tasks(
def try_generate_document_set_sync_tasks(
document_set: DocumentSet,
celery_app: Celery,
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -195,7 +223,7 @@ def try_generate_document_set_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
rds = RedisDocumentSet(document_set_id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
@@ -203,7 +231,10 @@ def try_generate_document_set_sync_tasks(
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(document_set)
document_set = get_document_set_by_id(db_session, document_set_id)
if not document_set:
return None
if document_set.is_up_to_date:
return None
@@ -238,7 +269,8 @@ def try_generate_document_set_sync_tasks(
def try_generate_user_group_sync_tasks(
usergroup: UserGroup,
celery_app: Celery,
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
@@ -246,14 +278,21 @@ def try_generate_user_group_sync_tasks(
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
rug = RedisUserGroup(usergroup_id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
# race condition with the monitor/cleanup function if we use a cached result!
db_session.refresh(usergroup)
fetch_user_group = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_group"
)
usergroup = fetch_user_group(db_session, usergroup_id)
if not usergroup:
return None
if usergroup.is_up_to_date:
return None
@@ -332,7 +371,7 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} "
f"Document set sync progress: document_set={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
@@ -347,12 +386,12 @@ def monitor_document_set_taskset(
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
f"Successfully deleted document set: document_set={document_set_id}"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
f"Successfully synced document set: document_set={document_set_id}"
)
r.delete(rds.taskset_key)
@@ -372,19 +411,29 @@ def monitor_connector_deletion_taskset(
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rcd.fence_key))
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.error("The value is not an integer.")
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
# the fence is setting up but isn't ready yet
if fence_data.num_tasks is None:
return
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
)
if count > 0:
return
@@ -447,7 +496,7 @@ def monitor_connector_deletion_taskset(
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
@@ -457,17 +506,17 @@ def monitor_connector_deletion_taskset(
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
f"docs_deleted={fence_data.num_tasks}"
)
r.delete(rcd.taskset_key)
@@ -577,7 +626,12 @@ def monitor_ccpair_indexing_taskset(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
if fence_data.index_attempt_id is None or fence_data.celery_task_id is None:
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
@@ -679,36 +733,9 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"pruning={n_pruning}"
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
@@ -726,8 +753,42 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"danswer.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed

View File

@@ -1,9 +1,8 @@
"""Entry point for running celery worker / celery beat."""
"""Factory stub for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
celery_app = fetch_versioned_implementation(
"danswer.background.celery.celery_app", "celery_app"
app = fetch_versioned_implementation(
"danswer.background.celery.apps.beat", "celery_app"
)

View File

@@ -0,0 +1,17 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from danswer.background.celery.apps.heavy import celery_app
return celery_app
app = get_app()

View File

@@ -0,0 +1,17 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from danswer.background.celery.apps.indexing import celery_app
return celery_app
app = get_app()

View File

@@ -0,0 +1,17 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from danswer.background.celery.apps.light import celery_app
return celery_app
app = get_app()

View File

@@ -0,0 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -1,6 +1,7 @@
import time
import traceback
from collections.abc import Callable
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -41,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -92,7 +106,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -206,6 +220,11 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
if (
(
@@ -263,8 +282,8 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
if callback:
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -394,7 +413,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:
@@ -417,7 +436,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "

View File

@@ -1,494 +0,0 @@
# TODO(rkuo): delete after background indexing via celery is fully vetted
# import logging
# import time
# from datetime import datetime
# import dask
# from dask.distributed import Client
# from dask.distributed import Future
# from distributed import LocalCluster
# from sqlalchemy import text
# from sqlalchemy.exc import ProgrammingError
# from sqlalchemy.orm import Session
# from danswer.background.indexing.dask_utils import ResourceLogger
# from danswer.background.indexing.job_client import SimpleJob
# from danswer.background.indexing.job_client import SimpleJobClient
# from danswer.background.indexing.run_indexing import run_indexing_entrypoint
# from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
# from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
# from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
# from danswer.configs.app_configs import MULTI_TENANT
# from danswer.configs.app_configs import NUM_INDEXING_WORKERS
# from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
# from danswer.configs.constants import DocumentSource
# from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
# from danswer.configs.constants import TENANT_ID_PREFIX
# from danswer.db.connector import fetch_connectors
# from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
# from danswer.db.engine import get_db_current_time
# from danswer.db.engine import get_session_with_tenant
# from danswer.db.engine import get_sqlalchemy_engine
# from danswer.db.engine import SqlEngine
# from danswer.db.index_attempt import create_index_attempt
# from danswer.db.index_attempt import get_index_attempt
# from danswer.db.index_attempt import get_inprogress_index_attempts
# from danswer.db.index_attempt import get_last_attempt_for_cc_pair
# from danswer.db.index_attempt import get_not_started_index_attempts
# from danswer.db.index_attempt import mark_attempt_failed
# from danswer.db.models import ConnectorCredentialPair
# from danswer.db.models import IndexAttempt
# from danswer.db.models import IndexingStatus
# from danswer.db.models import IndexModelStatus
# from danswer.db.models import SearchSettings
# from danswer.db.search_settings import get_current_search_settings
# from danswer.db.search_settings import get_secondary_search_settings
# from danswer.db.swap_index import check_index_swap
# from danswer.document_index.vespa.index import VespaIndex
# from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
# from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
# from danswer.utils.logger import setup_logger
# from danswer.utils.variable_functionality import global_version
# from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
# from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
# from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
# from shared_configs.configs import LOG_LEVEL
# logger = setup_logger()
# # If the indexing dies, it's most likely due to resource constraints,
# # restarting just delays the eventual failure, not useful to the user
# dask.config.set({"distributed.scheduler.allowed-failures": 0})
# _UNEXPECTED_STATE_FAILURE_REASON = (
# "Stopped mid run, likely due to the background process being killed"
# )
# def _should_create_new_indexing(
# cc_pair: ConnectorCredentialPair,
# last_index: IndexAttempt | None,
# search_settings_instance: SearchSettings,
# secondary_index_building: bool,
# db_session: Session,
# ) -> bool:
# connector = cc_pair.connector
# # don't kick off indexing for `NOT_APPLICABLE` sources
# if connector.source == DocumentSource.NOT_APPLICABLE:
# return False
# # User can still manually create single indexing attempts via the UI for the
# # currently in use index
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# if (
# search_settings_instance.status == IndexModelStatus.PRESENT
# and secondary_index_building
# ):
# return False
# # When switching over models, always index at least once
# if search_settings_instance.status == IndexModelStatus.FUTURE:
# if last_index:
# # No new index if the last index attempt succeeded
# # Once is enough. The model will never be able to swap otherwise.
# if last_index.status == IndexingStatus.SUCCESS:
# return False
# # No new index if the last index attempt is waiting to start
# if last_index.status == IndexingStatus.NOT_STARTED:
# return False
# # No new index if the last index attempt is running
# if last_index.status == IndexingStatus.IN_PROGRESS:
# return False
# else:
# if (
# connector.id == 0 or connector.source == DocumentSource.INGESTION_API
# ): # Ingestion API
# return False
# return True
# # If the connector is paused or is the ingestion API, don't index
# # NOTE: during an embedding model switch over, the following logic
# # is bypassed by the above check for a future model
# if (
# not cc_pair.status.is_active()
# or connector.id == 0
# or connector.source == DocumentSource.INGESTION_API
# ):
# return False
# if not last_index:
# return True
# if connector.refresh_freq is None:
# return False
# # Only one scheduled/ongoing job per connector at a time
# # this prevents cases where
# # (1) the "latest" index_attempt is scheduled so we show
# # that in the UI despite another index_attempt being in-progress
# # (2) multiple scheduled index_attempts at a time
# if (
# last_index.status == IndexingStatus.NOT_STARTED
# or last_index.status == IndexingStatus.IN_PROGRESS
# ):
# return False
# current_db_time = get_db_current_time(db_session)
# time_since_index = current_db_time - last_index.time_updated
# return time_since_index.total_seconds() >= connector.refresh_freq
# def _mark_run_failed(
# db_session: Session, index_attempt: IndexAttempt, failure_reason: str
# ) -> None:
# """Marks the `index_attempt` row as failed + updates the `
# connector_credential_pair` to reflect that the run failed"""
# logger.warning(
# f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
# f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
# )
# mark_attempt_failed(
# index_attempt=index_attempt,
# db_session=db_session,
# failure_reason=failure_reason,
# )
# """Main funcs"""
# def create_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
# ) -> None:
# """Creates new indexing jobs for each connector / credential pair which is:
# 1. Enabled
# 2. `refresh_frequency` time has passed since the last indexing run for this pair
# 3. There is not already an ongoing indexing attempt for this pair
# """
# with get_session_with_tenant(tenant_id) as db_session:
# ongoing: set[tuple[int | None, int]] = set()
# for attempt_id in existing_jobs:
# attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# if attempt is None:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
# "indexing jobs"
# )
# continue
# ongoing.add(
# (
# attempt.connector_credential_pair_id,
# attempt.search_settings_id,
# )
# )
# # Get the primary search settings
# primary_search_settings = get_current_search_settings(db_session)
# search_settings = [primary_search_settings]
# # Check for secondary search settings
# secondary_search_settings = get_secondary_search_settings(db_session)
# if secondary_search_settings is not None:
# # If secondary settings exist, add them to the list
# search_settings.append(secondary_search_settings)
# all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
# for cc_pair in all_connector_credential_pairs:
# for search_settings_instance in search_settings:
# # Check if there is an ongoing indexing attempt for this connector credential pair
# if (cc_pair.id, search_settings_instance.id) in ongoing:
# continue
# last_attempt = get_last_attempt_for_cc_pair(
# cc_pair.id, search_settings_instance.id, db_session
# )
# if not _should_create_new_indexing(
# cc_pair=cc_pair,
# last_index=last_attempt,
# search_settings_instance=search_settings_instance,
# secondary_index_building=len(search_settings) > 1,
# db_session=db_session,
# ):
# continue
# create_index_attempt(
# cc_pair.id, search_settings_instance.id, db_session
# )
# def cleanup_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# tenant_id: str | None,
# timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# # clean up completed jobs
# with get_session_with_tenant(tenant_id) as db_session:
# for attempt_id, job in existing_jobs.items():
# index_attempt = get_index_attempt(
# db_session=db_session, index_attempt_id=attempt_id
# )
# # do nothing for ongoing jobs that haven't been stopped
# if not job.done():
# if not index_attempt:
# continue
# if not index_attempt.is_finished():
# continue
# if job.status == "error":
# logger.error(job.exception())
# job.release()
# del existing_jobs_copy[attempt_id]
# if not index_attempt:
# logger.error(
# f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
# "up indexing jobs"
# )
# continue
# if (
# index_attempt.status == IndexingStatus.IN_PROGRESS
# or job.status == "error"
# ):
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# # clean up in-progress jobs that were never completed
# try:
# connectors = fetch_connectors(db_session)
# for connector in connectors:
# in_progress_indexing_attempts = get_inprogress_index_attempts(
# connector.id, db_session
# )
# for index_attempt in in_progress_indexing_attempts:
# if index_attempt.id in existing_jobs:
# # If index attempt is canceled, stop the run
# if index_attempt.status == IndexingStatus.FAILED:
# existing_jobs[index_attempt.id].cancel()
# # check to see if the job has been updated in last `timeout_hours` hours, if not
# # assume it to frozen in some bad state and just mark it as failed. Note: this relies
# # on the fact that the `time_updated` field is constantly updated every
# # batch of documents indexed
# current_db_time = get_db_current_time(db_session=db_session)
# time_since_update = current_db_time - index_attempt.time_updated
# if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
# existing_jobs[index_attempt.id].cancel()
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason="Indexing run frozen - no updates in the last three hours. "
# "The run will be re-attempted at next scheduled indexing time.",
# )
# else:
# # If job isn't known, simply mark it as failed
# _mark_run_failed(
# db_session=db_session,
# index_attempt=index_attempt,
# failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
# )
# except ProgrammingError:
# logger.debug(f"No Connector Table exists for: {tenant_id}")
# return existing_jobs_copy
# def kickoff_indexing_jobs(
# existing_jobs: dict[int, Future | SimpleJob],
# client: Client | SimpleJobClient,
# secondary_client: Client | SimpleJobClient,
# tenant_id: str | None,
# ) -> dict[int, Future | SimpleJob]:
# existing_jobs_copy = existing_jobs.copy()
# current_session = get_session_with_tenant(tenant_id)
# # Don't include jobs waiting in the Dask queue that just haven't started running
# # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
# with current_session as db_session:
# # get_not_started_index_attempts orders its returned results from oldest to newest
# # we must process attempts in a FIFO manner to prevent connector starvation
# new_indexing_attempts = [
# (attempt, attempt.search_settings)
# for attempt in get_not_started_index_attempts(db_session)
# if attempt.id not in existing_jobs
# ]
# logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
# if not new_indexing_attempts:
# return existing_jobs
# indexing_attempt_count = 0
# primary_client_full = False
# secondary_client_full = False
# for attempt, search_settings in new_indexing_attempts:
# if primary_client_full and secondary_client_full:
# break
# use_secondary_index = (
# search_settings.status == IndexModelStatus.FUTURE
# if search_settings is not None
# else False
# )
# if attempt.connector_credential_pair.connector is None:
# logger.warning(
# f"Skipping index attempt as Connector has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Connector is null"
# )
# continue
# if attempt.connector_credential_pair.credential is None:
# logger.warning(
# f"Skipping index attempt as Credential has been deleted: {attempt}"
# )
# with current_session as db_session:
# mark_attempt_failed(
# attempt, db_session, failure_reason="Credential is null"
# )
# continue
# if not use_secondary_index:
# if not primary_client_full:
# run = client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# primary_client_full = True
# else:
# if not secondary_client_full:
# run = secondary_client.submit(
# run_indexing_entrypoint,
# attempt.id,
# tenant_id,
# attempt.connector_credential_pair_id,
# global_version.is_ee_version(),
# pure=False,
# )
# if not run:
# secondary_client_full = True
# if run:
# if indexing_attempt_count == 0:
# logger.info(
# f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
# )
# indexing_attempt_count += 1
# secondary_str = " (secondary index)" if use_secondary_index else ""
# logger.info(
# f"Indexing dispatched{secondary_str}: "
# f"attempt_id={attempt.id} "
# f"connector='{attempt.connector_credential_pair.connector.name}' "
# f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
# f"credentials='{attempt.connector_credential_pair.credential_id}'"
# )
# existing_jobs_copy[attempt.id] = run
# if indexing_attempt_count > 0:
# logger.info(
# f"Indexing dispatch results: "
# f"initial_pending={len(new_indexing_attempts)} "
# f"started={indexing_attempt_count} "
# f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
# )
# return existing_jobs_copy
# def get_all_tenant_ids() -> list[str] | list[None]:
# if not MULTI_TENANT:
# return [None]
# with get_session_with_tenant(tenant_id="public") as session:
# result = session.execute(
# text(
# """
# SELECT schema_name
# FROM information_schema.schemata
# WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
# )
# )
# tenant_ids = [row[0] for row in result]
# valid_tenants = [
# tenant
# for tenant in tenant_ids
# if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
# ]
# return valid_tenants
# def update_loop(
# delay: int = 10,
# num_workers: int = NUM_INDEXING_WORKERS,
# num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
# ) -> None:
# if not MULTI_TENANT:
# # We can use this function as we are certain only the public schema exists
# # (explicitly for the non-`MULTI_TENANT` case)
# engine = get_sqlalchemy_engine()
# with Session(engine) as db_session:
# check_index_swap(db_session=db_session)
# search_settings = get_current_search_settings(db_session)
# # So that the first time users aren't surprised by really slow speed of first
# # batch of documents indexed
# if search_settings.provider_type is None:
# logger.notice("Running a first inference to warm up embedding model")
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(
# embedding_model=embedding_model,
# )
# logger.notice("First inference complete.")
# client_primary: Client | SimpleJobClient
# client_secondary: Client | SimpleJobClient
# if DASK_JOB_CLIENT_ENABLED:
# cluster_primary = LocalCluster(
# n_workers=num_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# cluster_secondary = LocalCluster(
# n_workers=num_secondary_workers,
# threads_per_worker=1,
# silence_logs=logging.ERROR,
# )
# client_primary = Client(cluster_primary)
# client_secondary = Client(cluster_secondary)
# if LOG_LEVEL.lower() == "debug":
# client_primary.register_worker_plugin(ResourceLogger())
# else:
# client_primary = SimpleJobClient(n_workers=num_workers)
# client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
# existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
# logger.notice("Startup complete. Waiting for indexing jobs...")
# while True:
# start = time.time()
# start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
# logger.debug(f"Running update, current UTC time: {start_time_utc}")
# if existing_jobs:
# logger.debug(
# "Found existing indexing jobs: "
# f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
# )
# try:
# tenants = get_all_tenant_ids()
# for tenant_id in tenants:
# try:
# logger.debug(
# f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
# )
# with get_session_with_tenant(tenant_id) as db_session:
# index_to_expire = check_index_swap(db_session=db_session)
# if index_to_expire and tenant_id and MULTI_TENANT:
# VespaIndex.delete_entries_by_tenant_id(
# tenant_id=tenant_id,
# index_name=index_to_expire.index_name,
# )
# if not MULTI_TENANT:
# search_settings = get_current_search_settings(db_session)
# if search_settings.provider_type is None:
# logger.notice(
# "Running a first inference to warm up embedding model"
# )
# embedding_model = EmbeddingModel.from_db_model(
# search_settings=search_settings,
# server_host=INDEXING_MODEL_SERVER_HOST,
# server_port=INDEXING_MODEL_SERVER_PORT,
# )
# warm_up_bi_encoder(embedding_model=embedding_model)
# logger.notice("First inference complete.")
# tenant_jobs = existing_jobs.get(tenant_id, {})
# tenant_jobs = cleanup_indexing_jobs(
# existing_jobs=tenant_jobs, tenant_id=tenant_id
# )
# create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
# tenant_jobs = kickoff_indexing_jobs(
# existing_jobs=tenant_jobs,
# client=client_primary,
# secondary_client=client_secondary,
# tenant_id=tenant_id,
# )
# existing_jobs[tenant_id] = tenant_jobs
# except Exception as e:
# logger.exception(
# f"Failed to process tenant {tenant_id or 'default'}: {e}"
# )
# except Exception as e:
# logger.exception(f"Failed to run update due to {e}")
# sleep_time = delay - (time.time() - start)
# if sleep_time > 0:
# time.sleep(sleep_time)
# def update__main() -> None:
# set_is_ee_based_on_env_variable()
# # initialize the Postgres connection pool
# SqlEngine.set_app_name(POSTGRES_INDEXER_APP_NAME)
# logger.notice("Starting indexing service")
# update_loop()
# if __name__ == "__main__":
# update__main()

View File

@@ -672,6 +672,7 @@ def stream_chat_message_objects(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(

View File

@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -131,7 +134,6 @@ try:
except ValueError:
INDEX_BATCH_SIZE = 16
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@@ -198,6 +200,41 @@ try:
except ValueError:
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
try:
CELERY_WORKER_LIGHT_CONCURRENCY = int(
os.environ.get(
"CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
)
)
except ValueError:
CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
try:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
os.environ.get(
"CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
)
)
except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_INDEXING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
#####
# Connector Configs
#####
@@ -253,12 +290,6 @@ CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)

View File

@@ -46,7 +46,6 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
@@ -71,6 +70,7 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -136,6 +136,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
class BlobType(str, Enum):

View File

@@ -13,8 +13,8 @@ Connectors come in 3 different flows:
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes additions and changes since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which generally be too slow to do frequently on large sets of
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.

View File

@@ -15,7 +15,6 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -24,6 +23,7 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()

View File

@@ -10,7 +10,6 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -19,6 +18,7 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.retry_wrapper import retry_builder
CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2"

View File

@@ -1,32 +0,0 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,226 @@
import math
import time
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
from atlassian import Confluence # type:ignore
from requests import HTTPError
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
class ConfluenceRateLimitError(Exception):
pass
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
class OnyxConfluence(Confluence):
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is necessary because the default Confluence class does not properly support cql expansions.
All methods are automatically wrapped with handle_confluence_rate_limit.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
wrap it with handle_confluence_rate_limit.
"""
for attr_name in dir(self):
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
setattr(
self,
attr_name,
handle_confluence_rate_limit(getattr(self, attr_name)),
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
"""
This will paginate through the top level query.
"""
if not limit:
limit = _DEFAULT_PAGINATION_LIMIT
connection_char = "&" if "?" in url_suffix else "?"
url_suffix += f"{connection_char}limit={limit}"
while url_suffix:
try:
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
def cql_paginate_all_expansions(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
The limit only applies to the top level query.
All expansion paginations use default pagination limit (defined by Atlassian).
"""
def _traverse_and_update(data: dict | list) -> None:
if isinstance(data, dict):
next_url = data.get("_links", {}).get("next")
if next_url and "results" in data:
data["results"].extend(self._paginate_url(next_url))
for value in data.values():
_traverse_and_update(value)
elif isinstance(data, list):
for item in data:
_traverse_and_update(item)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results

View File

@@ -1,219 +0,0 @@
import math
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar
from requests import HTTPError
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
class ConfluenceRateLimitError(Exception):
pass
# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
# max_retries = 5
# starting_delay = 5
# backoff = 2
# # max_delay is used when the server doesn't hand back "Retry-After"
# # and we have to decide the retry delay ourselves
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
# # max_retry_after is used when we do get a "Retry-After" header
# max_retry_after = 300 # should we really cap the maximum retry delay?
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries):
# try:
# # if multiple connectors are waiting for the next attempt, there could be an issue
# # where many connectors are "released" onto the server at the same time.
# # That's not ideal ... but coming up with a mechanism for queueing
# # all of these connectors is a bigger problem that we want to take on
# # right now
# try:
# next_attempt = r.get(NEXT_RETRY_KEY)
# if next_attempt is None:
# next_attempt = 0
# else:
# next_attempt = int(cast(int, next_attempt))
# # TODO: all connectors need to be interruptible moving forward
# while time.monotonic() < next_attempt:
# time.sleep(1)
# except ConnectionError:
# pass
# return confluence_call(*args, **kwargs)
# except HTTPError as e:
# # Check if the response or headers are None to avoid potential AttributeError
# if e.response is None or e.response.headers is None:
# logger.warning("HTTPError with `None` as response or as headers")
# raise e
# retry_after_header = e.response.headers.get("Retry-After")
# if (
# e.response.status_code == 429
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
# ):
# retry_after = None
# if retry_after_header is not None:
# try:
# retry_after = int(retry_after_header)
# except ValueError:
# pass
# if retry_after is not None:
# if retry_after > max_retry_after:
# logger.warning(
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
# )
# retry_after = max_delay
# logger.warning(
# f"Rate limit hit. Retrying after {retry_after} seconds..."
# )
# try:
# r.set(
# NEXT_RETRY_KEY,
# math.ceil(time.monotonic() + retry_after),
# )
# except ConnectionError:
# pass
# else:
# logger.warning(
# "Rate limit hit. Retrying with exponential backoff..."
# )
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# else:
# # re-raise, let caller handle
# raise
# except AttributeError as e:
# # Some error within the Confluence library, unclear why it fails.
# # Users reported it to be intermittent, so just retry
# logger.warning(f"Confluence Internal Error, retrying... {e}")
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# if attempt == max_retries - 1:
# raise e
# return cast(F, wrapped_call)
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)

View File

@@ -0,0 +1,214 @@
import io
from datetime import datetime
from datetime import timezone
from typing import Any
import bs4
from danswer.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from danswer.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.html_utils import format_document_soup
from danswer.utils.logger import setup_logger
logger = setup_logger()
_USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
email = None
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
Returns:
str: loaded and formated Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
return format_document_soup(soup)
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
return None
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
referenced_attachment_filenames = []
soup = bs4.BeautifulSoup(page_text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
referenced_attachment_filenames.append(attachment.attrs["ri:filename"])
return referenced_attachment_filenames
def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
# If no timezone info, assume it is UTC
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
else:
# If not in UTC, translate it
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)

View File

@@ -14,7 +14,6 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_st
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
@@ -24,6 +23,7 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()

View File

@@ -11,7 +11,6 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.document360.utils import flatten_child_categories
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
@@ -22,6 +21,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.retry_wrapper import retry_builder
# Limitations and Potential Improvements
# 1. The "Categories themselves contain potentially relevant information" but they're not pulled in

View File

@@ -64,7 +64,7 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,

View File

@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
@@ -28,7 +27,8 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -175,7 +175,7 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
token = current_tenant_id.set(self.tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_tenant(self.tenant_id) as db_session:
for file_path in self.file_locations:
@@ -199,7 +199,7 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
current_tenant_id.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if __name__ == "__main__":

View File

@@ -19,7 +19,6 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_ONLY_ORG_PUBLIC
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -40,6 +39,7 @@ from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()

View File

@@ -3,11 +3,13 @@ from collections.abc import Iterator
from typing import Any
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
class BaseConnector(abc.ABC):
@@ -52,9 +54,9 @@ class PollConnector(BaseConnector):
raise NotImplementedError
class IdConnector(BaseConnector):
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_source_ids(self) -> set[str]:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -161,7 +161,7 @@ class LoopioConnector(LoadConnector, PollConnector):
]
doc_batch.append(
Document(
id=entry["id"],
id=str(entry["id"]),
sections=[Section(link=link, text=content_text)],
source=DocumentSource.LOOPIO,
semantic_identifier=questions[0],

View File

@@ -14,7 +14,7 @@ class InputType(str, Enum):
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
POLL = "poll" # e.g. calling an API to get all documents in the last hour
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
PRUNE = "prune"
SLIM_RETRIEVAL = "slim_retrieval"
class ConnectorMissingCredentialError(PermissionError):
@@ -169,6 +169,11 @@ class Document(DocumentBase):
)
class SlimDocument(BaseModel):
id: str
perm_sync_data: Any | None = None
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str

View File

@@ -241,24 +241,29 @@ class NotionConnector(LoadConnector, PollConnector):
)
# TODO there may be more types to handle here
if "name" in inner_dict:
return inner_dict["name"]
if "content" in inner_dict:
return inner_dict["content"]
start = inner_dict.get("start")
end = inner_dict.get("end")
if start is not None:
if end is not None:
return f"{start} - {end}"
return start
elif end is not None:
return f"Until {end}"
if isinstance(inner_dict, str):
# For some objects the innermost value could just be a string, not sure what causes this
return inner_dict
if "id" in inner_dict:
# This is not useful to index, it's a reference to another Notion object
# and this ID value in plaintext is useless outside of the Notion context
logger.debug("Skipping Notion object id field property")
return None
elif isinstance(inner_dict, dict):
if "name" in inner_dict:
return inner_dict["name"]
if "content" in inner_dict:
return inner_dict["content"]
start = inner_dict.get("start")
end = inner_dict.get("end")
if start is not None:
if end is not None:
return f"{start} - {end}"
return start
elif end is not None:
return f"Until {end}"
if "id" in inner_dict:
# This is not useful to index, it's a reference to another Notion object
# and this ID value in plaintext is useless outside of the Notion context
logger.debug("Skipping Notion object id field property")
return None
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
return None
@@ -268,7 +273,13 @@ class NotionConnector(LoadConnector, PollConnector):
if not prop:
continue
inner_value = _recurse_properties(prop)
try:
inner_value = _recurse_properties(prop)
except Exception as e:
# This is not a critical failure, these properties are not the actual contents of the page
# more similar to metadata
logger.warning(f"Error recursing properties for {prop_name}: {e}")
continue
# Not a perfect way to format Notion database tables but there's no perfect representation
# since this must be represented as plaintext
if inner_value:

View File

@@ -11,17 +11,25 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.connectors.salesforce.utils import extract_dict_text
from danswer.utils.logger import setup_logger
# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
@@ -29,7 +37,7 @@ ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -243,19 +251,22 @@ class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
def retrieve_all_source_ids(self) -> set[str]:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
all_retrieved_ids: set[str] = set()
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
all_retrieved_ids.update(
f"{ID_PREFIX}{instance_dict.get('Id', '')}"
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
perm_sync_data={},
)
for instance_dict in query_result["records"]
)
return all_retrieved_ids
yield doc_metadata_list
if __name__ == "__main__":

View File

@@ -13,13 +13,15 @@ from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
@@ -326,7 +328,7 @@ def _get_all_doc_ids(
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> set[str]:
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
@@ -338,13 +340,14 @@ def _get_all_doc_ids(
all_channels, channels, channel_name_regex_enabled
)
all_doc_ids = set()
for channel in filtered_channels:
channel_id = channel["id"]
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
message_ts_set: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
@@ -353,12 +356,21 @@ def _get_all_doc_ids(
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we dont have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
message_ts_set.add(message["ts"])
return all_doc_ids
channel_metadata_list: list[SlimDocument] = []
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
perm_sync_data={"channel_id": channel_id},
)
)
yield channel_metadata_list
class SlackPollConnector(PollConnector, IdConnector):
class SlackPollConnector(PollConnector, SlimConnector):
def __init__(
self,
workspace: str,
@@ -379,7 +391,7 @@ class SlackPollConnector(PollConnector, IdConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_source_ids(self) -> set[str]:
def retrieve_all_slim_documents(self) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

View File

@@ -10,9 +10,9 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()

View File

@@ -373,7 +373,7 @@ class WebConnector(LoadConnector):
page.close()
except Exception as e:
last_error = f"Failed to fetch '{current_url}': {e}"
logger.error(last_error)
logger.exception(last_error)
playwright.stop()
restart_playwright = True
continue

View File

@@ -211,6 +211,7 @@ def handle_regular_answer(
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:

View File

@@ -7,7 +7,6 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@@ -47,6 +46,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
@@ -57,6 +57,7 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -345,7 +346,9 @@ def process_message(
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
logger.debug(f"Received Slack request of type: '{req.type}'")
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
)
# Throw out requests that can't or shouldn't be handled
if not prefilter_requests(req, client):
@@ -357,51 +360,59 @@ def process_message(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_bot_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
# always respond with the default configs
and not (details.is_bot_msg or details.bypass_filters)
):
return
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_bot_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
# always respond with the default configs
and not (details.is_bot_msg or details.bypass_filters)
):
return
follow_up = bool(
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
follow_up = bool(
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
failed = handle_message(
message_info=details,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
failed = handle_message(
message_info=details,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if client.tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
@@ -499,7 +510,9 @@ if __name__ == "__main__":
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
@@ -533,6 +546,11 @@ if __name__ == "__main__":
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client

View File

@@ -341,6 +341,8 @@ def add_credential_to_connector(
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
last_successful_index_time: datetime | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
@@ -384,9 +386,10 @@ def add_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,
status=ConnectorCredentialPairStatus.ACTIVE,
status=initial_status,
access_type=access_type,
auto_sync_options=auto_sync_options,
last_successful_index_time=last_successful_index_time,
)
db_session.add(association)
db_session.flush() # make sure the association has an id

View File

@@ -40,6 +40,8 @@ CREDENTIAL_PERMISSIONS_TO_IGNORE = {
DocumentSource.MEDIAWIKI,
}
PUBLIC_CREDENTIAL_ID = 0
def _add_user_filters(
stmt: Select,
@@ -241,7 +243,7 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an IDcredentials
db_session.flush() # This ensures the credential gets an ID
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,
@@ -384,12 +386,11 @@ def delete_credential(
def create_initial_public_credential(db_session: Session) -> None:
public_cred_id = 0
error_msg = (
"DB is not in a valid initial state."
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
first_credential = fetch_credential_by_id(PUBLIC_CREDENTIAL_ID, None, db_session)
if first_credential is not None:
if first_credential.credential_json != {} or first_credential.user is not None:
@@ -397,7 +398,7 @@ def create_initial_public_credential(db_session: Session) -> None:
return
credential = Credential(
id=public_cred_id,
id=PUBLIC_CREDENTIAL_ID,
credential_json={},
user_id=None,
)
@@ -405,6 +406,24 @@ def create_initial_public_credential(db_session: Session) -> None:
db_session.commit()
def cleanup_gmail_credentials(db_session: Session) -> None:
gmail_credentials = fetch_credentials_by_source(
db_session=db_session, user=None, document_source=DocumentSource.GMAIL
)
for credential in gmail_credentials:
db_session.delete(credential)
db_session.commit()
def cleanup_google_drive_credentials(db_session: Session) -> None:
google_drive_credentials = fetch_credentials_by_source(
db_session=db_session, user=None, document_source=DocumentSource.GOOGLE_DRIVE
)
for credential in google_drive_credentials:
db_session.delete(credential)
db_session.commit()
def delete_gmail_service_account_credentials(
user: User | None, db_session: Session
) -> None:

View File

@@ -375,6 +375,20 @@ def update_docs_last_modified__no_commit(
doc.last_modified = now
def mark_document_as_modified(
document_id: str,
db_session: Session,
) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc = db_session.scalar(stmt)
if doc is None:
raise ValueError(f"No document with ID: {document_id}")
# update last_synced
doc.last_modified = datetime.now(timezone.utc)
db_session.commit()
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc = db_session.scalar(stmt)

View File

@@ -36,10 +36,11 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -188,6 +189,29 @@ class SqlEngine:
return cls._app_name
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
result = session.execute(
text(
f"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
@@ -236,12 +260,12 @@ def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
token = request.cookies.get("tenant_details")
if not token:
current_value = current_tenant_id.get()
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no token is present, use the default schema or handle accordingly
return current_value
@@ -249,14 +273,14 @@ def get_current_tenant_id(request: Request) -> str:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
return current_tenant_id.get()
return CURRENT_TENANT_ID_CONTEXTVAR.get()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
current_tenant_id.set(tenant_id)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return current_tenant_id.get()
return CURRENT_TENANT_ID_CONTEXTVAR.get()
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
@@ -267,7 +291,7 @@ async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = current_tenant_id.get()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
@@ -299,9 +323,9 @@ def get_session_with_tenant(
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
else:
current_tenant_id.set(tenant_id)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
@@ -337,26 +361,22 @@ def get_session_with_tenant(
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
if tenant_id == "public" and MULTI_TENANT:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()
@@ -372,7 +392,7 @@ def get_session() -> Generator[Session, None, None]:
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:

View File

@@ -1,5 +1,6 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import and_
@@ -66,6 +67,32 @@ def create_index_attempt(
return new_attempt.id
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,
docs_indexed: int,
db_session: Session,
) -> int:
"""Should not be used in any user triggered flows"""
db_time = func.now()
new_attempt = IndexAttempt(
connector_credential_pair_id=connector_credential_pair_id,
search_settings_id=search_settings_id,
from_beginning=True,
status=IndexingStatus.SUCCESS,
total_docs_indexed=docs_indexed,
new_docs_indexed=docs_indexed,
# Need this to be some convincing random looking value and it can't be 0
# or the indexing rate would calculate out to infinity
time_started=db_time - timedelta(seconds=1.92),
time_updated=db_time,
)
db_session.add(new_attempt)
db_session.commit()
return new_attempt.id
def get_in_progress_index_attempts(
connector_id: int | None,
db_session: Session,

View File

@@ -95,10 +95,11 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
return full_llm_provider
def fetch_existing_embedding_providers(

View File

@@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from danswer.auth.schemas import UserRole
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification
from danswer.db.models import User
@@ -54,7 +55,9 @@ def get_notification_by_id(
notif = db_session.get(Notification, notification_id)
if not notif:
raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id:
if notif.user_id != user_id and not (
notif.user_id is None and user is not None and user.role == UserRole.ADMIN
):
raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}"
)

View File

@@ -328,7 +328,6 @@ def update_all_personas_display_priority(
for persona in personas:
persona.display_priority = display_priority_map[persona.id]
db_session.commit()

View File

@@ -8,7 +8,6 @@ from typing import Any
from typing import cast
import httpx
import requests
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
@@ -194,20 +193,21 @@ def _get_chunks_via_visit_api(
document_chunks: list[dict] = []
while True:
response = requests.get(url, params=params)
try:
response.raise_for_status()
except requests.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}"
filtered_params = {k: v for k, v in params.items() if v is not None}
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=filtered_params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to query Vespa"
logger.error(
f"{error_base}:\n"
f"{request_info}\n"
f"{response_info}\n"
f"Exception: {e}"
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
)
raise requests.HTTPError(error_base) from e
raise httpx.HTTPError(error_base) from e
# Check if the response contains any documents
response_data = response.json()
@@ -301,17 +301,13 @@ def query_vespa(
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
except httpx.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = (
f"Status Code: {response.status_code}\n"
f"Response Content: {response.text}"
)
error_base = "Failed to query Vespa"
logger.error(
f"{error_base}:\n"
f"{request_info}\n"
f"{response_info}\n"
f"Exception: {e}"
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
)
raise httpx.HTTPError(error_base) from e

View File

@@ -118,7 +118,7 @@ def get_existing_documents_from_chunks(
return document_ids
@retry(tries=3, delay=1, backoff=2)
@retry(tries=5, delay=1, backoff=2)
def _index_vespa_chunk(
chunk: DocMetadataAwareIndexChunk,
index_name: str,

View File

@@ -103,6 +103,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
self,
chunks: list[DocAwareChunk],
) -> list[IndexChunk]:
"""Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
if they exist. If there is no space for it, it would have been thrown out at the chunking step.
"""
# All chunks at this point must have some non-empty content
flat_chunk_texts: list[str] = []
large_chunks_present = False
@@ -121,6 +124,11 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
flat_chunk_texts.append(chunk_text)
if chunk.mini_chunk_texts:
if chunk.large_chunk_reference_ids:
# A large chunk does not contain mini chunks, if it matches the large chunk
# with a high score, then mini chunks would not be used anyway
# otherwise it should match the normal chunk
raise RuntimeError("Large chunk contains mini chunks")
flat_chunk_texts.extend(chunk.mini_chunk_texts)
embeddings = self.embedding_model.encode(

View File

@@ -195,6 +195,8 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)

View File

@@ -3,5 +3,6 @@ from danswer.key_value_store.store import PgRedisKVStore
def get_kv_store() -> KeyValueStore:
# this is the only one supported currently
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore()

View File

@@ -14,6 +14,8 @@ class KvKeyNotFoundError(Exception):
class KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
@abc.abstractmethod
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
raise NotImplementedError

View File

@@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import cast
from fastapi import HTTPException
from redis.client import Redis
from sqlalchemy import text
from sqlalchemy.orm import Session
@@ -16,8 +17,8 @@ from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -27,17 +28,23 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = current_tenant_id.get()
if tenant_id == "public":
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)

View File

@@ -237,6 +237,7 @@ class Answer:
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
@@ -331,7 +332,10 @@ class Answer:
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[str | StreamStopInfo]:
for message in self.llm.stream(
prompt=prompt, tools=tools, tool_choice=tool_choice
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
structured_response_format=self.answer_style_config.structured_response_format,
):
if isinstance(message, AIMessageChunk):
if message.content:

View File

@@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":

View File

@@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
@@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**(
{"response_format": structured_response_format}
if structured_response_format
else {}
),
**self._model_kwargs,
)
except Exception as e:
@@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
if hasattr(choice, "message"):
@@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
self._completion(
prompt, tools, tool_choice, True, structured_response_format
),
)
try:
for part in response:

View File

@@ -80,6 +80,7 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -88,5 +89,6 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -51,6 +51,7 @@ def get_llms_for_persona(
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
@@ -104,7 +105,7 @@ def get_default_llms(
def get_llm(
provider: str,
model: str,
deployment_name: str | None = None,
deployment_name: str | None,
api_key: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
@@ -116,6 +117,7 @@ def get_llm(
return DefaultMultiLLM(
model_provider=provider,
model_name=model,
deployment_name=deployment_name,
api_key=api_key,
api_base=api_base,
api_version=api_version,

View File

@@ -88,11 +88,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(prompt, tools, tool_choice)
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _invoke_implementation(
@@ -100,6 +103,7 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -108,11 +112,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(prompt, tools, tool_choice)
return self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _stream_implementation(
@@ -120,5 +127,6 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError

View File

@@ -61,6 +61,7 @@ BEDROCK_MODEL_NAMES = [
IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
"claude-instant-1",
"anthropic/claude-3-5-sonnet-20241022",
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [
@@ -100,8 +101,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-5-sonnet-20240620",
default_fast_model="claude-3-5-sonnet-20240620",
default_model="claude-3-5-sonnet-20241022",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(
name=AZURE_PROVIDER_NAME,
@@ -135,8 +136,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
),
],
llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
default_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
default_fast_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
default_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
default_fast_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
),
]

View File

@@ -342,12 +342,26 @@ def get_llm_max_tokens(
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
logger.debug(f"Using model object for {model_name}")
else:
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
if "max_input_tokens" in model_obj:
max_tokens = model_obj["max_input_tokens"]
logger.info(

View File

@@ -96,6 +96,7 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -213,7 +214,7 @@ def get_application() -> FastAPI:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.5,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:

View File

@@ -61,6 +61,30 @@ class TenantRedis(redis.Redis):
return wrapper
def _prefix_scan_iter(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Prefix the match pattern if provided
if "match" in kwargs:
kwargs["match"] = self._prefixed(kwargs["match"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
# Get the iterator
iterator = method(*args, **kwargs)
# Remove prefix from returned keys
prefix = f"{self.tenant_id}:".encode()
prefix_len = len(prefix)
for key in iterator:
if isinstance(key, bytes) and key.startswith(prefix):
yield key[prefix_len:]
else:
yield key
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
@@ -74,13 +98,15 @@ class TenantRedis(redis.Redis):
"hset",
"hget",
"getset",
"scan_iter",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
] # Regular methods that need simple prefixing
if item == "scan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr

View File

@@ -10,7 +10,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.engine import current_tenant_id
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
@@ -162,7 +162,7 @@ def retrieval_preprocessing(
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,212 @@
import datetime
import json
import os
from typing import cast
from sqlalchemy.orm import Session
from danswer.access.models import default_public_access
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_DOCUMENTS_SEEDED_KEY
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.connectors.models import Section
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.credentials import PUBLIC_CREDENTIAL_ID
from danswer.db.document import check_docs_exist
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mock_successful_index_attempt
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.indexing_pipeline import index_doc_batch_prepare
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import ConnectorBase
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
logger = setup_logger()
def _create_indexable_chunks(
preprocessed_docs: list[dict],
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
ids_to_documents = {}
chunks = []
for preprocessed_doc in preprocessed_docs:
document = Document(
id=preprocessed_doc["url"], # For Web connector, the URL is the ID
# The section is not really used past this point since we have already done the other processing
# for the chunking and embedding.
sections=[
Section(text=preprocessed_doc["content"], link=preprocessed_doc["url"])
],
source=DocumentSource.WEB,
semantic_identifier=preprocessed_doc["title"],
metadata={},
doc_updated_at=None,
primary_owners=[],
secondary_owners=[],
)
if preprocessed_doc["chunk_ind"] == 0:
ids_to_documents[document.id] = document
chunk = DocMetadataAwareIndexChunk(
chunk_id=preprocessed_doc["chunk_ind"],
blurb=preprocessed_doc["content"]
.split(".", 1)[0]
.split("!", 1)[0]
.split("?", 1)[0],
content=preprocessed_doc["content"],
source_links={0: preprocessed_doc["url"]},
section_continuation=False,
source_document=document,
title_prefix=preprocessed_doc["title"],
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
embeddings=ChunkEmbedding(
full_embedding=preprocessed_doc["content_embedding"],
mini_chunk_embeddings=[],
),
title_embedding=preprocessed_doc["title_embedding"],
tenant_id=None,
access=default_public_access,
document_sets=set(),
boost=DEFAULT_BOOST,
)
chunks.append(chunk)
return list(ids_to_documents.values()), chunks
def seed_initial_documents(db_session: Session) -> None:
"""
Seed initial documents so users don't have an empty index to start
Documents are only loaded if:
- This is the first setup (if the user deletes the docs, we don't load them again)
- The index is empty, there are no docs and no (non-default) connectors
- The user has not updated the embedding models
- If they do, then we have to actually index the website
- If the embedding model is already updated on server startup, they're not a new user
Note that regardless of any search settings, the default documents are always loaded with
the predetermined chunk sizes and single pass embedding.
Steps are as follows:
- Check if this needs to run
- Create the connector representing this
- Create the cc-pair (attaching the public credential) and mocking values like the last success
- Indexing the documents into Postgres
- Indexing the documents into Vespa
- Create a fake index attempt with fake times
"""
logger.info("Seeding initial documents")
kv_store = get_kv_store()
try:
kv_store.load(KV_DOCUMENTS_SEEDED_KEY)
logger.info("Documents already seeded, skipping")
return
except KvKeyNotFoundError:
pass
if check_docs_exist(db_session):
logger.info("Documents already exist, skipping")
return
if check_connectors_exist(db_session):
logger.info("Connectors already exist, skipping")
return
search_settings = get_current_search_settings(db_session)
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL:
logger.info("Embedding model has been updated, skipping")
return
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Create a connector so the user can delete it if they want
# or reindex it with a new search model if they want
connector_data = ConnectorBase(
name="Sample Use Cases",
source=DocumentSource.WEB,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"base_url": "https://docs.danswer.dev/more/use_cases",
"web_connector_type": "recursive",
},
refresh_freq=None, # Never refresh by default
prune_freq=None,
indexing_start=None,
)
connector = create_connector(db_session, connector_data)
connector_id = cast(int, connector.id)
last_index_time = datetime.datetime.now(datetime.timezone.utc)
result = add_credential_to_connector(
db_session=db_session,
user=None,
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,
access_type=AccessType.PUBLIC,
cc_pair_name=connector_data.name,
groups=None,
initial_status=ConnectorCredentialPairStatus.PAUSED,
last_successful_index_time=last_index_time,
)
cc_pair_id = cast(int, result.data)
initial_docs_path = os.path.join(
os.getcwd(), "danswer", "seeding", "initial_docs.json"
)
processed_docs = json.load(open(initial_docs_path))
docs, chunks = _create_indexable_chunks(processed_docs)
index_doc_batch_prepare(
document_batch=docs,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,
),
db_session=db_session,
ignore_time_skip=True, # Doesn't actually matter here
)
# In this case since there are no other connectors running in the background
# and this is a fresh deployment, there is no need to grab any locks
logger.info(
"Indexing seeding documents into Vespa "
"(Vespa may take a few seconds to become ready after receiving the schema)"
)
# Retries here because the index may take a few seconds to become ready
# as we just sent over the Vespa schema and there is a slight delay
index_with_retries = retry_builder()(document_index.index)
index_with_retries(chunks=chunks)
# Mock a run for the UI even though it did not actually call out to anything
mock_successful_index_attempt(
connector_credential_pair_id=cc_pair_id,
search_settings_id=search_settings.id,
docs_indexed=len(docs),
db_session=db_session,
)
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)

View File

@@ -9,6 +9,7 @@ from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_documents_by_cc_pair
from danswer.db.document import get_ingestion_documents
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
@@ -67,6 +68,7 @@ def upsert_ingestion_doc(
doc_info: IngestionDocument,
_: User | None = Depends(api_key_dep),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> IngestionResult:
doc_info.document.from_ingestion_api = True
@@ -101,6 +103,7 @@ def upsert_ingestion_doc(
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
new_doc, __chunk_count = indexing_pipeline(
@@ -134,6 +137,7 @@ def upsert_ingestion_doc(
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
sec_ind_pipeline(

Some files were not shown because too many files have changed in this diff Show More