Compare commits

..

68 Commits

Author SHA1 Message Date
hagen-danswer
e2aaa60e77 done 2024-11-21 17:26:18 -08:00
hagen-danswer
31f4a68bee less passing around is_cloud 2024-11-21 16:58:31 -08:00
hagen-danswer
4fc196fc39 properly escaped the user query 2024-11-21 16:37:48 -08:00
hagen-danswer
95ab63b6bc reworked it 2024-11-21 16:35:08 -08:00
hagen-danswer
6d26d0b929 replace deprecated confluence group api endpoint 2024-11-21 12:37:16 -08:00
pablodanswer
deee237c7e Sheet update (#3189)
* quick pass

* k

* update sheet

* add multiple sheet stuff

* k

* finalized

* update configuration
2024-11-21 18:07:00 +00:00
hagen-danswer
100b4a0d16 Added Slim connector for Jira (#3181)
* Added Slim connector for Jira

* fixed testing

* more cleanup of Jira connector

* cleanup
2024-11-21 17:00:20 +00:00
rkuo-danswer
70207b4b39 improve web testing (#3162)
* shared admin level test dependency

* change to on - push (recommended by chromatic)

* change playwright reporter to list, name test jobs

* use test tags ... much cleaner

* test vs prod

* try copying templates

* run with localhost?

* revert to dev

* new tests and a bit of refactoring

* add additional checks so that page snapshots reflect loaded state

* more admin tests

* User Management tests

* remaining admin pages

* test search and chat

* await fix and exclude UI that changes with dates.
2024-11-21 04:01:15 +00:00
pablodanswer
50826b6bef Formatting Niceties (#3183)
* search bar formatting

* update styling
2024-11-21 03:11:26 +00:00
pablodanswer
3f648cbc31 Folder clarity (#3180)
* folder clarity

* k
2024-11-21 03:11:17 +00:00
pablodanswer
c875a4774f valid props (#3186) 2024-11-21 01:13:54 +00:00
hagen-danswer
049091eb01 decreased confluence retry times and added more logging (#3184)
* decreased confluence retry times and added more logging

* added check on connector startup

* no retries!

* fr no retries
2024-11-21 00:00:14 +00:00
pablodanswer
3dac24542b silence small error (#3182) 2024-11-20 22:46:38 +00:00
pablodanswer
194dcb593d update slack redirect + token missing check (#3179)
* update slack redirect + token missing check

* reset time
2024-11-20 21:42:54 +00:00
pablodanswer
bf291d0c0a Fix missing json (#3177)
* initial steps

* k

* remove logs

* k

* k
2024-11-20 21:24:43 +00:00
rkuo-danswer
8309f4a802 test overlapping connectors (but using a source that is way too big a… (#3152)
* test overlapping connectors (but using a source that is way too big and slow, fix that next)

* pass thru secrets

* rename

* rename again

* now we are fixing it

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-20 21:12:01 +00:00
pablodanswer
0ff2565125 ensure margin properly applied (#3176)
* ensure margin properly applied

* formatting
2024-11-20 20:04:45 +00:00
hagen-danswer
e89dcd7f84 added logging and bugfixing to conf (#3167)
* standardized escaping of CQL strings

* think i found it

* fix

* should be fixed

* added handling for special linking behavior in confluence

* Update onyx_confluence.py

* Update onyx_confluence.py

---------

Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-11-20 18:40:21 +00:00
pablodanswer
645e7e828e Add Google Tag Manager for Web Cloud Build (#3173)
* add gtm for cloud build

* update github workflow
2024-11-20 17:38:33 +00:00
pablodanswer
2a54f14195 ensure everythigng has a default max height in selectorformfield (#3174) 2024-11-20 17:26:22 +00:00
hagen-danswer
9209fc804b multiple slackbot support (#3077)
* multiple slackbot support

* app_id + tenant_id key

* removed kv store stuff

* fixed up mypy and migration

* got frontend working for multiple slack bots

* some frontend stuff

* alembic fix

* might be valid

* refactor dun

* alembic stuff

* temp frontend stuff

* alembic stuff

* maybe fixed alembic

* maybe dis fix

* im getting mad

* api names changed

* tested

* almost done

* done

* routing nonsense

* done!

* done!!

* fr done

* doneski

* fix alembic migration

* getting mad again

* PLEASE IM BEGGING YOU
2024-11-20 01:49:43 +00:00
rkuo-danswer
b712877701 Merge pull request #3165 from danswer-ai/bugfix/pruning_logs
improve logging around pruning
2024-11-19 13:19:31 -08:00
Richard Kuo (Danswer)
e6df32dcc3 improve logging around pruning 2024-11-19 12:41:21 -08:00
Chris Weaver
eb81258a23 Update README.md
Fix slack link
2024-11-19 08:02:35 -08:00
hagen-danswer
487ef4acc0 Merge pull request #3160 from danswer-ai/add-to-admin-chat-sessions-api
Extend query history API
2024-11-19 07:28:12 -08:00
pablodanswer
9b7cc83eae add new date search filter (#3065)
* add new complicated filters

* clarity updates

* update date range filter
2024-11-19 03:42:42 +00:00
Weves
ce3124f9e4 Extend query history API 2024-11-18 17:50:21 -08:00
rkuo-danswer
e69303e309 add helpful hint on 507 (#3157)
* add helpful hint on 507

* add helpful hint to the direct exception in _index_vespa_chunk
2024-11-19 01:08:32 +00:00
rkuo-danswer
6e698ac84a Hardening deletion when cc pair relationships are left over (#3154)
* more logs

* this fence should be set to None

* type hinting

* reset deletion attempt if conditions are inconsistent

* always clean up in db if we reach reconciliation

* add reset method

* more logging

* harden up error checking
2024-11-19 01:07:59 +00:00
pablodanswer
d69180aeb8 add additional theming options (#3155)
* add additional theming options

* nit

* Update Filters.tsx
2024-11-18 22:56:48 +00:00
rkuo-danswer
aa37051be9 Bugfix/indexing redux (#3151)
* raise indexing lock timeout

* refactor unknown index attempts and redis lock
2024-11-18 22:47:31 +00:00
pablodanswer
a7d95661b3 Add assistant categories (#3064)
* add assistant categories v1

* functionality finalized

* finalize

* update assistant category display

* nit

* add tests

* post rebase update

* minor update to tests

* update typing

* finalize

* typing

* nit

* alembic

* alembic (once again)
2024-11-18 20:33:48 +00:00
Chris Weaver
33ee899408 Long term logs (#3150) 2024-11-18 10:48:03 -08:00
hagen-danswer
954b5b2a56 Made external permissioned users and slack users show diff (#3147)
* Made external permissioned users and slack users show diff

* finished

* Fix typing

* k

* Fix

* k

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-11-17 01:13:47 +00:00
pablodanswer
521425a4f2 nits + pricing 2024-11-16 16:28:37 -08:00
hagen-danswer
618bc02d54 Fixed int test (#3148) 2024-11-16 18:13:06 +00:00
rkuo-danswer
b7de74fdf8 Feature/playwright tests (#3129)
* initial PoC

* preliminary working config

* first cut at chromatic tests

* first cut at chromatic tests

* fix yaml

* fix yaml again

* use workingDir

* adapt playwright example

* remove env

* fix working directory

* fix more paths

* fix dir

* add playwright setup

* accidentally deleted a step

* update test

* think we don't need home.png right now

* remove unused home.png

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-16 04:26:17 +00:00
hagen-danswer
6e83fe3a39 reworked drive+confluence frontend and implied backend changes (#3143)
* reworked drive+confluence frontend and implied backend changes

* fixed oauth admin tests

* fixed service account tests

* frontend cleanup

* copy change

* details!

* added key

* so good

* whoops!

* fixed mnore treljsertjoslijt

* has issue with boolean form

* should be done
2024-11-16 03:38:30 +00:00
Weves
259fc049b7 Add error message on JSON decode error in CustomTool 2024-11-15 20:00:12 -08:00
rkuo-danswer
7015e6f2ab Bugfix/overlapping connectors (#3138)
* fix tenant logging

* upsert only new/updated docs, but always upsert document to cc pair relationship

* better logging and rough cut at testing
2024-11-16 00:47:52 +00:00
pablodanswer
24be13c015 Improved tokenizer fallback (#3132)
* silence warning

* improved fallback logic

* k

* minor cosmetic update

* minor logic update

* nit
2024-11-14 20:13:29 -08:00
pablodanswer
ddff7ecc3f minor configuration updates (#3134) 2024-11-14 18:09:30 -08:00
Yuhong Sun
97932dc44b Fix Quotes Prompting (#3137) 2024-11-14 17:28:03 -08:00
rkuo-danswer
637b6d9e75 Merge pull request #3135 from danswer-ai/bugfix/helm_ct_python_setup
unnecessary python setup
2024-11-14 14:57:12 -08:00
Richard Kuo (Danswer)
54dc1ac917 unnecessary python setup 2024-11-14 11:14:12 -08:00
rkuo-danswer
21d5cc43f8 Merge pull request #3131 from danswer-ai/bugfix/session_text
use text()
2024-11-13 20:24:14 -08:00
pablodanswer
7c841051ed Cohere (#3111)
* add cohere default

* finalize

* minor improvement

* update

* update

* update configs

* ensure we properly expose name(space) for slackbot

* update config

* config
2024-11-14 01:58:54 +00:00
pablodanswer
6e91964924 minor clarity (#3116) 2024-11-14 01:42:21 +00:00
pablodanswer
facf1d55a0 Cloud improvements (#3099)
* add improved cloud configuration

* fix typing

* finalize slackbot improvements

* minor update

* finalized keda

* moderate slackbot switch

* update some configs

* revert

* include reset engine!
2024-11-13 23:52:52 +00:00
rkuo-danswer
d68f8d6fbc scale indexing sql pool based on concurrency (#3130) 2024-11-13 23:26:13 +00:00
Richard Kuo (Danswer)
65a205d488 use text() 2024-11-13 15:03:21 -08:00
hagen-danswer
485f3f72fa Updated google copy and added non admin oauth support (#3120)
* Updated google copy and added non admin oauth support

* backend update

* accounted for oauth

* further removed class variables

* updated sets
2024-11-13 20:07:10 +00:00
rkuo-danswer
dcbea883ae add creator id to cc pair (#3121)
* add creator id to cc pair

* fix alembic head

* show email instead of UUID

* safer check on email

* make foreign key relationships optional

* always allow creator to edit (per hagen)

* use primary join

* no index_doc_batch spam

* try this again

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-13 19:35:08 +00:00
hagen-danswer
a50a3944b3 Make curators able to create permission synced connectors (#3126)
* Make curators able to create permission synced connectors

* removed editing permission synced connectors for curators

* updated tests to use access type instead of is_public

* update copy
2024-11-13 18:58:23 +00:00
hagen-danswer
60471b6a73 Added support for page within a page in Confluence (#3125) 2024-11-13 16:39:00 +00:00
rkuo-danswer
d703e694ce limited role api keys (#3115)
* in progress PoC

* working limited user, needs routes to be marked next

* make selected endpoint available to limited user role

* xfail on test_slack_prune

* add comment to sync function

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-13 16:15:43 +00:00
hagen-danswer
6066042fef Merge pull request #3124 from danswer-ai/fix-doc-sync
quick fix for google doc sync
2024-11-13 07:30:52 -08:00
hagen-danswer
eb0e20b9e4 quick fix for google doc sync 2024-11-13 07:24:29 -08:00
pablodanswer
490a68773b update organization (#3118)
* update organization

* minor clean up

* add minor clarity

* k

* slight rejigger

* alembic fix

* update paradigm

* delete code!

* delete code

* minor update
2024-11-13 06:45:32 +00:00
rkuo-danswer
227aff1e47 clean up logging in light worker (#3072) 2024-11-13 03:42:02 +00:00
Weves
6e29d1944c Fix widget example 2024-11-12 18:48:44 -08:00
pablodanswer
22189f02c6 Add referral source to cloud on data plane (#3096)
* cloud auth referral source

* minor clarity

* k

* minor modification to be best practice

* typing

* Update ReferralSourceSelector.tsx

* Update ReferralSourceSelector.tsx

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-11-13 00:42:25 +00:00
hagen-danswer
fdc4811fce doc sync celery refactor (#3084)
* doc_sync is refactored

* maybe this works

* tested to work!

* mypy fixes

* enabled integration tests

* fixed the test

* added external group sync

* testing should work now

* mypy

* confluence doc id fix

* got group sync working

* addressed feedback

* renamed some vars and fixed mypy

* conf fix?

* added wiki handling to confluence connector

* test fixes

* revert google drive connector

* fixed groups

* hotfix
2024-11-12 23:57:14 +00:00
Chris Weaver
021d0cf314 Support LITELLM_EXTRA_BODY env variable (#3119)
* Support LITELLM_EXTRA_BODY env variable

* Remove unused param

* Add comment
2024-11-12 23:17:44 +00:00
pablodanswer
942e47db29 improved mobile scroll (#3110) 2024-11-12 01:57:49 +00:00
pablodanswer
f4a020b599 moderate component fixes (#3095)
* moderate component fixes

* nit

* nit

* update colors

* k
2024-11-12 00:47:35 +00:00
pablodanswer
5166649eae Cleaner EE fallback for no op (#3106)
* treat async values differently

* cleaner approach

* spacing

* typing
2024-11-11 17:42:14 +00:00
Chris Weaver
ba805f766f New assistants api (#3097) 2024-11-11 07:55:23 -08:00
345 changed files with 18709 additions and 14390 deletions

View File

@@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

225
.github/workflows/pr-chromatic-tests.yml vendored Normal file
View File

@@ -0,0 +1,225 @@
name: Run Chromatic Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -23,21 +23,6 @@ jobs:
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
@@ -52,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -13,7 +13,10 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -195,9 +198,13 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests

View File

@@ -20,6 +20,7 @@ env:
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/

View File

@@ -203,7 +203,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
],
"presentation": {
"group": "2",
@@ -232,7 +232,7 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",

View File

@@ -12,7 +12,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -135,7 +135,7 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>

View File

@@ -0,0 +1,68 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_new =
CASE
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
)
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"chosen_assistants_old",
postgresql.JSONB(),
nullable=False,
server_default="[-2, -1, 0]",
),
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_old =
CASE
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
)

View File

@@ -0,0 +1,30 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_external_group_sync",
sa.DateTime(timezone=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_time_external_group_sync")

View File

@@ -0,0 +1,45 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -0,0 +1,280 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

View File

@@ -7,6 +7,7 @@ Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
@@ -15,7 +16,6 @@ from danswer.db.models import (
Credential,
IndexAttempt,
)
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
@@ -30,13 +30,11 @@ def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
connectors_to_delete = (
session.query(Connector)
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
.all()
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
)
connector_ids = [connector.id for connector in connectors_to_delete]
connector_ids = [row[0] for row in result]
if connector_ids:
cc_pairs_to_delete = (

View File

@@ -0,0 +1,30 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"creator_id",
sa.UUID(as_uuid=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "creator_id")

View File

@@ -288,6 +288,15 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -23,6 +23,56 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -0,0 +1,42 @@
"""extended_role_for_non_web
Revision ID: dfbe9e93d3c7
Revises: 9cf5c00f72fe
Create Date: 2024-11-16 07:54:18.727906
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7"
down_revision = "9cf5c00f72fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE "user"
SET role = 'EXT_PERM_USER'
WHERE has_web_login = false
"""
)
op.drop_column("user", "has_web_login")
def downgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
op.execute(
"""
UPDATE "user"
SET has_web_login = false,
role = 'BASIC'
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
"""
)

View File

@@ -16,6 +16,41 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
external_access: ExternalAccess
# The document ID
doc_id: str
def to_dict(self) -> dict:
return {
"external_access": {
"external_user_emails": list(self.external_access.external_user_emails),
"external_user_group_ids": list(
self.external_access.external_user_group_ids
),
"is_public": self.external_access.is_public,
},
"doc_id": self.doc_id,
}
@classmethod
def from_dict(cls, data: dict) -> "DocExternalAccess":
external_access = ExternalAccess(
external_user_emails=set(
data["external_access"].get("external_user_emails", [])
),
external_user_group_ids=set(
data["external_access"].get("external_user_group_ids", [])
),
is_public=data["external_access"]["is_public"],
)
return cls(
external_access=external_access,
doc_id=data["doc_id"],
)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin

View File

@@ -2,8 +2,8 @@ from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:

View File

@@ -13,12 +13,24 @@ class UserRole(str, Enum):
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
- Limited can access a limited set of basic api endpoints
- Slack are users that have used danswer via slack but dont have a web login
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,
UserRole.EXT_PERM_USER,
]
class UserStatus(str, Enum):
@@ -33,10 +45,8 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -49,7 +49,6 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
@@ -222,18 +221,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
self,
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -242,7 +248,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
self.user_db = tenant_user_db
self.database = tenant_user_db
@@ -261,14 +269,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
@@ -282,7 +285,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
self,
oauth_name: str,
access_token: str,
account_id: str,
@@ -293,13 +296,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
if not tenant_id:
@@ -314,9 +322,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
self.user_db = tenant_user_db
self.database = tenant_user_db # type: ignore
self.database = tenant_user_db
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -368,7 +378,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
@@ -381,16 +395,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login: # type: ignore
if not user.role.is_web_login():
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"has_web_login": True,
"role": UserRole.BASIC,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -465,9 +478,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
if not user.role.is_web_login():
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -652,12 +663,26 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_user(
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
@@ -711,8 +736,6 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -762,15 +785,22 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -829,8 +859,11 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
# Authenticate user
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -872,7 +905,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -24,6 +24,8 @@ from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
@@ -136,6 +138,22 @@ def on_task_postrun(
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""

View File

@@ -12,6 +12,7 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -72,6 +73,15 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")

View File

@@ -91,5 +91,7 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -6,6 +6,7 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -59,7 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
@@ -81,6 +82,11 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -92,5 +92,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

View File

@@ -14,12 +14,18 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -134,6 +140,27 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -233,6 +260,8 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",

View File

@@ -1,96 +0,0 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
If the SlimConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.

View File

@@ -8,7 +8,7 @@ tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
@@ -20,13 +20,13 @@ tasks_to_schedule = [
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
@@ -41,6 +41,18 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": "check_for_external_group_sync",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]

View File

@@ -1,12 +1,12 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
@@ -143,6 +143,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
f"cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()

View File

@@ -0,0 +1,321 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip doc permissions sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
return True
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name="check_for_doc_permissions_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
if redis_connector.permissions.fenced:
return None
if redis_connector.delete.fenced:
return None
if redis_connector.prune.fenced:
return None
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
)
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name="connector_permission_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
self.app, lock, document_external_accesses, source_type
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)
raise e
finally:
if lock.owned():
lock.release()
@shared_task(
name="update_external_document_permissions_task",
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
return False

View File

@@ -0,0 +1,265 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
return False
# If the last sync is None, it has never been run so we run the sync
last_ext_group_sync = cc_pair.last_time_external_group_sync
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name="check_for_external_group_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
return None
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name="connector_external_group_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type}")
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
if lock.owned():
lock.release()

View File

@@ -3,13 +3,14 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -44,7 +45,7 @@ from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -61,14 +62,18 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
@@ -76,7 +81,19 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
try:
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
@@ -175,7 +192,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
if attempt_id:
task_logger.info(
f"Indexing queued: index_attempt={attempt_id} "
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
)
@@ -325,7 +343,7 @@ def try_creating_indexing_task(
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -366,9 +384,8 @@ def try_creating_indexing_task(
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(payload)
redis_connector_index.set_fence(None)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "
@@ -499,7 +516,8 @@ def connector_indexing_task(
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"

View File

@@ -38,6 +38,42 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
Next pruning time is calculated as a delta from the last successful prune, or the
last successful indexing if pruning has never succeeded.
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
if pruning fails (which is what it does now). A backoff could be reasonable.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
@@ -69,7 +105,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue
if not is_pruning_due(cc_pair, db_session, r):
if not _is_pruning_due(cc_pair):
continue
tasks_created = try_creating_prune_generator_task(
@@ -90,47 +126,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -166,10 +161,16 @@ def try_creating_prune_generator_task(
return None
try:
if redis_connector.prune.fenced: # skip pruning if already pruning
# skip pruning if already pruning
if redis_connector.prune.fenced:
return None
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
return None
db_session.refresh(cc_pair)
@@ -231,6 +232,8 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
@@ -261,6 +264,11 @@ def connector_pruning_generator_task(
)
return
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -275,6 +283,7 @@ def connector_pruning_generator_task(
lock,
r,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
@@ -296,8 +305,8 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
@@ -320,10 +329,10 @@ def connector_pruning_generator_task(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
redis_connector.prune.reset()
raise e
finally:
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")

View File

@@ -59,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"tenant={tenant_id} doc={document_id}")
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -141,7 +141,9 @@ def document_by_cc_pair_cleanup_task(
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
@@ -171,11 +173,21 @@ def document_by_cc_pair_cleanup_task(
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id):
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
return False

View File

@@ -13,6 +13,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
@@ -27,6 +28,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -58,6 +60,10 @@ from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
@@ -162,7 +168,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -180,7 +186,12 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -188,22 +199,21 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
if result is None:
continue
if tasks_generated == 0:
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += tasks_generated
total_tasks_generated += result[0]
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -218,7 +228,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -246,12 +256,11 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -260,7 +269,7 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -273,7 +282,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -302,12 +311,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -316,7 +324,7 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -436,11 +444,22 @@ def monitor_connector_deletion_taskset(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# gating off pruning and indexing work before deletion starts
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
@@ -504,8 +523,7 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
redis_connector.delete.reset()
def monitor_ccpair_pruning_taskset(
@@ -546,6 +564,47 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -580,8 +639,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -596,33 +655,41 @@ def monitor_ccpair_indexing_taskset(
result_state = result.state
status_int = redis_connector_index.get_completion()
if status_int is None:
if status_int is None: # completion signal not set ... check for errors
# If we get here, and then the task both sets the completion signal and finishes,
# we will incorrectly abort the task. We must check result state, then check
# get_completion again to avoid the race condition.
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
if redis_connector_index.get_completion() is None:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
task_logger.warning(msg)
redis_connector_index.reset()
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -630,6 +697,37 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -643,7 +741,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -668,40 +766,37 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -741,6 +836,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -811,7 +912,9 @@ def vespa_metadata_sync_task(
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()

View File

@@ -29,18 +29,26 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Ensure the parent proc's database connections are not touched
in the new connection pool
"""Initialize the child process with a fresh SQLAlchemy Engine.
Based on the recommended approach in the SQLAlchemy docs found:
Based on SQLAlchemy's recommendations to handle multiprocessing:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)

View File

@@ -33,8 +33,8 @@ from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -427,17 +427,19 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_cc_and_index_id(
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -445,10 +447,8 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -1,4 +0,0 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -14,15 +14,6 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@@ -122,10 +111,8 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
@@ -295,7 +282,6 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -307,6 +293,9 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -328,6 +317,11 @@ def stream_chat_message_objects(
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
@@ -353,6 +347,7 @@ def stream_chat_message_objects(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
long_term_logger=long_term_logger,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -428,12 +423,20 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -504,13 +507,19 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -525,7 +534,13 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -537,6 +552,7 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -560,142 +576,39 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -871,7 +784,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -879,9 +791,11 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
@@ -915,7 +829,6 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -925,7 +838,6 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -503,3 +503,7 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")

View File

@@ -60,7 +60,6 @@ KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
@@ -74,12 +73,16 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -209,9 +212,17 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
@@ -221,8 +232,18 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"

View File

@@ -119,3 +119,14 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# if specified, will merge the specified JSON with the existing body of the
# request before sending it to the LLM
LITELLM_EXTRA_BODY: dict | None = None
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
if _LITELLM_EXTRA_BODY_RAW:
try:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass

View File

@@ -5,9 +5,9 @@ from io import BytesIO
from typing import Any
from typing import Optional
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from mypy_boto3_s3 import S3Client # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType

View File

@@ -7,9 +7,9 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -70,7 +70,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.confluence_client: OnyxConfluence | None = None
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -81,15 +81,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
elif space:
# if no cql_query is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
@@ -97,39 +97,44 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = build_confluence_client(
credentials_json=credentials,
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comments in self.confluence_client.paginated_cql_page_retrieval(
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
@@ -141,28 +146,28 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"]
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
)
object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
self.confluence_client, confluence_object
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
self.confluence_client, confluence_object
confluence_client=self.confluence_client, attachment=confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -193,44 +198,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachments in self.confluence_client.paginated_cql_page_retrieval(
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -255,48 +255,47 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for pages in self.confluence_client.cql_paginate_all_expansions(
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
):
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
yield doc_metadata_list
doc_metadata_list = []

View File

@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -95,6 +99,10 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -141,7 +149,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
"""
@@ -153,46 +161,43 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
def paginated_cql_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
yield from self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -201,7 +206,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -221,6 +226,110 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_without_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
)
spaces = confluence_client_without_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -71,7 +72,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -79,7 +82,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
@@ -100,6 +103,73 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
@@ -153,7 +223,9 @@ def attachment_to_content(
return extracted_text
def build_confluence_document_id(base_url: str, content_url: str) -> str:
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@@ -164,6 +236,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
@@ -195,20 +269,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)

View File

@@ -1,8 +1,8 @@
import os
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import Issue
@@ -12,129 +12,93 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
from danswer.connectors.danswer_jira.utils import build_jira_client
from danswer.connectors.danswer_jira.utils import build_jira_url
from danswer.connectors.danswer_jira.utils import extract_jira_project
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
from danswer.connectors.danswer_jira.utils import get_comment_strs
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
def _paginate_jql_search(
jira_client: JIRA,
jql: str,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
return jira_base, jira_project
if len(issues) < max_results:
break
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in jira.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
start += max_results
def fetch_jira_issues_batch(
jql: str,
start_index: int,
jira_client: JIRA,
batch_size: int = INDEX_BATCH_SIZE,
jql: str,
batch_size: int,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> tuple[list[Document], int]:
doc_batch = []
batch = jira_client.search_issues(
jql,
startAt=start_index,
maxResults=batch_size,
)
for jira in batch:
if type(jira) != Issue:
logger.warning(f"Found Jira object not of type Issue {jira}")
continue
if labels_to_skip and any(
label in jira.fields.labels for label in labels_to_skip
):
logger.info(
f"Skipping {jira.key} because it has a label to skip. Found "
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
jira.fields.description
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"])
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
@@ -142,66 +106,53 @@ def fetch_jira_issues_batch(
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
people = set()
try:
people.add(
BasicExpertInfo(
display_name=jira.fields.creator.displayName,
email=jira.fields.creator.emailAddress,
)
)
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
people.add(
BasicExpertInfo(
display_name=jira.fields.assignee.displayName, # type: ignore
email=jira.fields.assignee.emailAddress, # type: ignore
)
)
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
yield Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=issue.fields.summary,
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
return doc_batch, len(batch)
class JiraConnector(LoadConnector, PollConnector):
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_project_url: str,
@@ -213,8 +164,8 @@ class JiraConnector(LoadConnector, PollConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
self.jira_client: JIRA | None = None
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@@ -223,54 +174,45 @@ class JiraConnector(LoadConnector, PollConnector):
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@property
def jira_client(self) -> JIRA:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
return self._jira_client
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
self.jira_client = JIRA(
basic_auth=(email, api_token),
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
self.jira_client = JIRA(
token_auth=api_token,
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
jql = f"project = {self.quoted_jira_project}"
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
yield document_batch
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -278,31 +220,54 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {quoted_project} AND "
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=jql,
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
jira_client=self.jira_client,
jql=jql,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
issue_key = best_effort_get_field_from_issue(issue, "key")
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
perm_sync_data=None,
)
)
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
yield slim_doc_batch
if __name__ == "__main__":

View File

@@ -1,17 +1,136 @@
"""Module with custom fields processing functions"""
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
display_name = None
email = None
if hasattr(obj, "display_name"):
display_name = obj.display_name
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
if not email and not display_name:
return None
return BasicExpertInfo(display_name=display_name, email=email)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
return JIRA(
basic_auth=(email, api_token),
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
return JIRA(
token_auth=api_token,
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
return jira_base, jira_project
def get_comment_strs(
issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in issue.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
class CustomFieldExtractor:
@staticmethod
def _process_custom_field_value(value: Any) -> str:

View File

@@ -305,6 +305,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,

View File

@@ -15,6 +15,7 @@ from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
@@ -82,12 +83,31 @@ def _process_files_batch(
yield doc_batch
def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
all_drive_ids_available: set[str],
) -> tuple[set[str], set[str]]:
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
if invalid_requested_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_requested_drive_ids)
valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = True,
include_shared_drives: bool = False,
include_my_drives: bool = False,
include_files_shared_with_me: bool = False,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
@@ -120,22 +140,36 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
)
self.batch_size = batch_size
self.include_shared_drives = include_shared_drives
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
self.include_my_drives = False if specific_requests_made else include_my_drives
self.include_shared_drives = (
False if specific_requests_made else include_shared_drives
)
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
@@ -192,80 +226,72 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
def _get_all_user_emails(self) -> list[str]:
# Start with primary admin email
user_emails = [self.primary_admin_email]
# Only fetch additional users if using service account
if isinstance(self.creds, OAuthCredentials):
return user_emails
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
# Get admins first since they're more likely to have access to most files
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
if email not in user_emails:
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
all_drive_ids = set()
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
useDomainAdminAccess=is_service_account,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first becuase they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
if not all_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
"No drives found even though we are indexing shared drives was requested."
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
return all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
):
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -274,7 +300,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
@@ -285,7 +311,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
@@ -296,33 +322,142 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
def _fetch_drive_items(
def _manage_service_account_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
self._impersonate_user_for_retrieval,
email,
is_slim,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
): email
for email in self._all_org_emails
for email in all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _manage_oauth_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
start=start,
end=end,
)
all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
all_drive_ids = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
retrieval_method = (
self._manage_service_account_retrieval
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -48,6 +49,67 @@ def _extract_sections_basic(
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -65,6 +127,7 @@ def _extract_sections_basic(
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,

View File

@@ -140,8 +140,8 @@ def get_files_in_shared_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -152,7 +152,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -160,9 +160,9 @@ def get_files_in_shared_drive(
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
@@ -172,7 +172,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=file_query,
)
@@ -185,14 +185,16 @@ def get_all_files_in_my_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = "trashed = false and 'me' in owners"
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -200,18 +202,52 @@ def get_all_files_in_my_drive(
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=file_query,
)
def get_all_files_for_oauth(
service: Any,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
corpora = "allDrives" if should_get_all else "user"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
if not should_get_all:
if include_files_shared_with_me and not include_my_drives:
file_query += " and not 'me' in owners"
if not include_files_shared_with_me and include_my_drives:
file_query += " and 'me' in owners"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
)

View File

@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e

View File

@@ -2,8 +2,8 @@ import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackBotConfig
from danswer.db.slack_bot_config import fetch_slack_bot_configs
from danswer.db.models import SlackChannelConfig
from danswer.db.slack_channel_config import fetch_slack_channel_configs
VALID_SLACK_FILTERS = [
@@ -13,53 +13,59 @@ VALID_SLACK_FILTERS = [
]
def get_slack_bot_config_for_channel(
channel_name: str | None, db_session: Session
) -> SlackBotConfig | None:
def get_slack_channel_config_for_bot_and_channel(
db_session: Session,
slack_bot_id: int,
channel_name: str | None,
) -> SlackChannelConfig | None:
if not channel_name:
return None
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
slack_bot_configs = fetch_slack_channel_configs(
db_session=db_session, slack_bot_id=slack_bot_id
)
for config in slack_bot_configs:
if channel_name in config.channel_config["channel_names"]:
if channel_name in config.channel_config["channel_name"]:
return config
return None
def validate_channel_names(
channel_names: list[str],
current_slack_bot_config_id: int | None,
def validate_channel_name(
db_session: Session,
) -> list[str]:
"""Make sure that these channel_names don't exist in other slack bot configs.
Returns a list of cleaned up channel names (e.g. '#' removed if present)"""
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
cleaned_channel_names = [
channel_name.lstrip("#").lower() for channel_name in channel_names
]
for slack_bot_config in slack_bot_configs:
if slack_bot_config.id == current_slack_bot_config_id:
current_slack_bot_id: int,
channel_name: str,
current_slack_channel_config_id: int | None,
) -> str:
"""Make sure that this channel_name does not exist in other Slack channel configs.
Returns a cleaned up channel name (e.g. '#' removed if present)"""
slack_bot_configs = fetch_slack_channel_configs(
db_session=db_session,
slack_bot_id=current_slack_bot_id,
)
cleaned_channel_name = channel_name.lstrip("#").lower()
for slack_channel_config in slack_bot_configs:
if slack_channel_config.id == current_slack_channel_config_id:
continue
for channel_name in cleaned_channel_names:
if channel_name in slack_bot_config.channel_config["channel_names"]:
raise ValueError(
f"Channel name '{channel_name}' already exists in "
"another slack bot config"
)
if cleaned_channel_name == slack_channel_config.channel_config["channel_name"]:
raise ValueError(
f"Channel name '{channel_name}' already exists in "
"another Slack channel config with in Slack Bot with name: "
f"{slack_channel_config.slack_bot.name}"
)
return cleaned_channel_names
return cleaned_channel_name
# Scaling configurations for multi-tenant Slack bot handling
# Scaling configurations for multi-tenant Slack channel handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
60 # How often pods send heartbeats to indicate they are still processing a tenant
15 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -13,7 +13,7 @@ from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
@@ -117,8 +117,10 @@ def handle_generate_answer_button(
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
handle_regular_answer(
@@ -133,7 +135,7 @@ def handle_generate_answer_button(
is_bot_msg=False,
is_bot_dm=False,
),
slack_bot_config=slack_bot_config,
slack_channel_config=slack_channel_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
@@ -256,11 +258,13 @@ def handle_followup_button(
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
if slack_bot_config:
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
if slack_channel_config:
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_user_ids_from_emails(

View File

@@ -19,8 +19,8 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.db.models import SlackChannelConfig
from danswer.db.users import add_slack_user_if_not_exists
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -106,7 +106,7 @@ def remove_scheduled_feedback_reminder(
def handle_message(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
slack_channel_config: SlackChannelConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str | None,
@@ -140,7 +140,7 @@ def handle_message(
)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
persona = slack_channel_config.persona if slack_channel_config else None
prompt = None
if persona:
document_set_names = [
@@ -152,8 +152,8 @@ def handle_message(
respond_member_group_list = None
channel_conf = None
if slack_bot_config and slack_bot_config.channel_config:
channel_conf = slack_bot_config.channel_config
if slack_channel_config and slack_channel_config.channel_config:
channel_conf = slack_channel_config.channel_config
if not bypass_filters and "answer_filters" in channel_conf:
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
@@ -213,13 +213,13 @@ def handle_message(
with get_session_with_tenant(tenant_id) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(db_session, message_info.email)
add_slack_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,
slack_bot_config=slack_bot_config,
slack_channel_config=slack_channel_config,
prompt=prompt,
logger=logger,
client=client,
@@ -231,7 +231,7 @@ def handle_message(
# if no standard answer applies, try a regular answer
issue_with_regular_answer = handle_regular_answer(
message_info=message_info,
slack_bot_config=slack_bot_config,
slack_channel_config=slack_channel_config,
receiver_ids=send_to,
client=client,
channel=channel,

View File

@@ -34,8 +34,8 @@ from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.db.users import get_user_by_email
@@ -81,7 +81,7 @@ def rate_limits(
def handle_regular_answer(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
slack_channel_config: SlackChannelConfig | None,
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
@@ -96,7 +96,7 @@ def handle_regular_answer(
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
@@ -108,7 +108,7 @@ def handle_regular_answer(
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
persona = slack_channel_config.persona if slack_channel_config else None
prompt = None
if persona:
document_set_names = [
@@ -120,9 +120,9 @@ def handle_regular_answer(
bypass_acl = False
if (
slack_bot_config
and slack_bot_config.persona
and slack_bot_config.persona.document_sets
slack_channel_config
and slack_channel_config.persona
and slack_channel_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
@@ -131,8 +131,8 @@ def handle_regular_answer(
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_bot_config is None
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
if slack_channel_config is None
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to and not is_bot_msg:
@@ -234,8 +234,8 @@ def handle_regular_answer(
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = (
slack_bot_config.enable_auto_filters
if slack_bot_config is not None
slack_channel_config.enable_auto_filters
if slack_channel_config is not None
else False
)
retrieval_details = RetrievalDetails(

View File

@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackChannelConfig
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
@@ -14,7 +14,7 @@ logger = setup_logger()
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
slack_channel_config: SlackChannelConfig | None,
prompt: Prompt | None,
logger: DanswerLoggingAdapter,
client: WebClient,
@@ -29,7 +29,7 @@ def handle_standard_answers(
return versioned_handle_standard_answers(
message_info=message_info,
receiver_ids=receiver_ids,
slack_bot_config=slack_bot_config,
slack_channel_config=slack_channel_config,
prompt=prompt,
logger=logger,
client=client,
@@ -40,7 +40,7 @@ def handle_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
slack_channel_config: SlackChannelConfig | None,
prompt: Prompt | None,
logger: DanswerLoggingAdapter,
client: WebClient,

View File

@@ -4,6 +4,7 @@ import signal
import sys
import threading
import time
from collections.abc import Callable
from threading import Event
from types import FrameType
from typing import Any
@@ -16,14 +17,17 @@ from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
@@ -52,20 +56,20 @@ from danswer.danswerbot.slack.handlers.handle_message import (
)
from danswer.danswerbot.slack.handlers.handle_message import schedule_feedback_reminder
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import check_message_limit
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import get_danswer_bot_slack_bot_id
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackBot
from danswer.db.search_settings import get_current_search_settings
from danswer.db.slack_bot import fetch_slack_bots
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
@@ -75,16 +79,21 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
"active_tenants",
"Number of active tenants handled by this pod",
["namespace", "pod"],
)
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
@@ -108,8 +117,10 @@ class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@@ -147,7 +158,9 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
@@ -162,11 +175,63 @@ class SlackbotHandler:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
del self.socket_clients[tenant_bot_pair]
del self.slack_bot_tokens[tenant_bot_pair]
return
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
tokens_changed = (
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
)
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
)
else:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
@@ -190,63 +255,30 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
bots = fetch_slack_bots(db_session=db_session)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
@@ -265,26 +297,37 @@ class SlackbotHandler:
)
def start_socket_client(
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
logger.info(
f"Starting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
socket_client: TenantSocketModeClient = _get_socket_client(
slack_bot_tokens, tenant_id, slack_bot_id
)
# Append the event handler
process_slack_event = create_process_slack_event()
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
logger.info(
f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.socket_clients[tenant_id, slack_bot_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
@@ -298,6 +341,16 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
@@ -358,7 +411,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
return False
bot_tag_id = get_danswer_bot_app_id(client.web_client)
bot_tag_id = get_danswer_bot_slack_bot_id(client.web_client)
if event_type == "message":
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and bot_tag_id in msg
@@ -381,13 +434,15 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
# If DanswerBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_bot_config
or not slack_bot_config.channel_config.get("respond_to_bots")
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info("Ignoring message from bot")
return False
@@ -592,14 +647,16 @@ def process_message(
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_bot_config is None
slack_channel_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
@@ -610,9 +667,10 @@ def process_message(
return
follow_up = bool(
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
slack_channel_config
and slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags")
is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
@@ -620,7 +678,7 @@ def process_message(
failed = handle_message(
message_info=details,
slack_bot_config=slack_bot_config,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
@@ -672,26 +730,32 @@ def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None
return process_feedback(req, client)
def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
def create_process_slack_event() -> (
Callable[[TenantSocketModeClient, SocketModeRequest], None]
):
def process_slack_event(
client: TenantSocketModeClient, req: SocketModeRequest
) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
try:
if req.type == "interactive":
if req.payload.get("type") == "block_actions":
return action_routing(req, client)
elif req.payload.get("type") == "view_submission":
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception as e:
logger.exception(f"Failed to process slack event. Error: {e}")
logger.error(f"Slack request payload: {req.payload}")
try:
if req.type == "interactive":
if req.payload.get("type") == "block_actions":
return action_routing(req, client)
elif req.payload.get("type") == "view_submission":
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception:
logger.exception("Failed to process slack event")
return process_slack_event
def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.danswer.dev/slack_bot_setup
@@ -700,6 +764,7 @@ def _get_socket_client(
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)

View File

@@ -1,28 +0,0 @@
import os
from typing import cast
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import SlackBotTokens
def fetch_tokens() -> SlackBotTokens:
# first check env variables
app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
if app_token and bot_token:
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
dynamic_config_store = get_kv_store()
return SlackBotTokens(
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
)
def save_tokens(
tokens: SlackBotTokens,
) -> None:
dynamic_config_store = get_kv_store()
dynamic_config_store.store(
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
)

View File

@@ -30,7 +30,6 @@ from danswer.configs.danswerbot_configs import (
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
@@ -47,16 +46,16 @@ from danswer.utils.text_processing import replace_whitespaces_w_space
logger = setup_logger()
_DANSWER_BOT_APP_ID: str | None = None
_DANSWER_BOT_SLACK_BOT_ID: str | None = None
_DANSWER_BOT_MESSAGE_COUNT: int = 0
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_APP_ID
if _DANSWER_BOT_APP_ID is None:
_DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_APP_ID
def get_danswer_bot_slack_bot_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_SLACK_BOT_ID
if _DANSWER_BOT_SLACK_BOT_ID is None:
_DANSWER_BOT_SLACK_BOT_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_SLACK_BOT_ID
def check_message_limit() -> bool:
@@ -137,15 +136,10 @@ def update_emote_react(
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_danswer_bot_app_id(web_client=client)
bot_tag_id = get_danswer_bot_slack_bot_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
def get_web_client() -> WebClient:
slack_tokens = fetch_tokens()
return WebClient(token=slack_tokens.bot_token)
@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -437,9 +431,9 @@ def read_slack_thread(
)
message_type = MessageType.USER
else:
self_app_id = get_danswer_bot_app_id(client)
self_slack_bot_id = get_danswer_bot_slack_bot_id(client)
if reply.get("user") == self_app_id:
if reply.get("user") == self_slack_bot_id:
# DanswerBot response
message_type = MessageType.ASSISTANT
user_sem_id = "Assistant"
@@ -582,6 +576,9 @@ def get_feedback_visibility() -> FeedbackVisibility:
class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
def __init__(
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id
self.slack_bot_id = slack_bot_id

View File

@@ -4,6 +4,7 @@ from typing import Any
from typing import Dict
from fastapi import Depends
from fastapi_users.models import ID
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
@@ -43,7 +44,10 @@ def get_total_users_count(db_session: Session) -> int:
"""
user_count = (
db_session.query(User)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.filter(
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
User.role != UserRole.EXT_PERM_USER,
)
.count()
)
invited_users = len(get_invited_users())
@@ -61,7 +65,7 @@ async def get_user_count() -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -282,3 +282,32 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()
def mark_cc_pair_as_permissions_synced(
db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.last_time_perm_sync = start_time
db_session.commit()
def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
# The sync time can be marked after it ran because all group syncs
# are run in full, not polling for changes.
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()

View File

@@ -76,8 +76,10 @@ def _add_user_filters(
.where(~UG__CCpair.user_group_id.in_(user_groups))
.correlate(ConnectorCredentialPair)
)
where_clause |= ConnectorCredentialPair.creator_id == user.id
else:
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
return stmt.where(where_clause)
@@ -387,6 +389,7 @@ def add_credential_to_connector(
)
association = ConnectorCredentialPair(
creator_id=user.id if user else None,
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,

View File

@@ -19,6 +19,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -46,13 +47,21 @@ def count_documents_by_needs_sync(session: Session) -> int:
"""Get the count of all documents where:
1. last_modified is newer than last_synced
2. last_synced is null (meaning we've never synced)
AND the document has a relationship with a connector/credential pair
TODO: The documents without a relationship with a connector/credential pair
should be cleaned up somehow eventually.
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count())
session.query(func.count(DbDocument.id.distinct()))
.select_from(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.filter(
or_(
DbDocument.last_modified > DbDocument.last_synced,
@@ -91,6 +100,22 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
return list(db_session.scalars(stmt).all())
def construct_document_select_for_connector_credential_pair(
connector_id: int, credential_id: int | None = None
) -> Select:
@@ -104,6 +129,21 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_documents_for_cc_pair(
db_session: Session,
cc_pair_id: int,
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
stmt = construct_document_select_for_connector_credential_pair(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
return list(db_session.scalars(stmt).all())
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
@@ -169,6 +209,7 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
@@ -306,6 +347,8 @@ def upsert_documents(
]
)
# This does not update the permissions of the document if
# the document already exists.
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
@@ -323,23 +366,23 @@ def upsert_documents(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
if not document_metadata_batch:
logger.info("`document_metadata_batch` is empty. Skipping.")
if not document_ids:
logger.info("`document_ids` is empty. Skipping.")
return
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
[
model_to_dict(
DocumentByConnectorCredentialPair(
id=document_metadata.document_id,
connector_id=document_metadata.connector_id,
credential_id=document_metadata.credential_id,
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
)
)
for document_metadata in document_metadata_batch
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
@@ -400,17 +443,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
db_session.commit()
def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
@@ -463,7 +495,6 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session
@@ -520,7 +551,7 @@ def prepare_to_modify_documents(
db_session.commit() # ensure that we're not in a transaction
lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
for i in range(_NUM_LOCK_ATTEMPTS):
try:
with db_session.begin() as transaction:
lock_acquired = acquire_document_locks(
@@ -531,7 +562,7 @@ def prepare_to_modify_documents(
break
except OperationalError as e:
logger.warning(
f"Failed to acquire locks for documents, retrying. Error: {e}"
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
)
time.sleep(retry_delay)

View File

@@ -189,6 +189,13 @@ class SqlEngine:
return ""
return cls._app_name
@classmethod
def reset_engine(cls) -> None:
with cls._lock:
if cls._engine:
cls._engine.dispose()
cls._engine = None
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
@@ -312,7 +319,9 @@ async def get_async_session_with_tenant(
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
@@ -373,7 +382,9 @@ def get_session_with_tenant(
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()

View File

@@ -53,7 +53,7 @@ from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.key_value_store.interface import JSON_ro
from danswer.utils.special_types import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
@@ -126,8 +126,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
visible_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
@@ -171,8 +171,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="creator",
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
class InputPrompt(Base):
@@ -347,11 +350,11 @@ class StandardAnswer__StandardAnswerCategory(Base):
)
class SlackBotConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_bot_config__standard_answer_category"
class SlackChannelConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_channel_config__standard_answer_category"
slack_bot_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot_config.id"), primary_key=True
slack_channel_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_channel_config.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
@@ -420,6 +423,9 @@ class ConnectorCredentialPair(Base):
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
@@ -452,6 +458,14 @@ class ConnectorCredentialPair(Base):
"IndexAttempt", back_populates="connector_credential_pair"
)
# the user id of the user that created this cc pair
creator_id: Mapped[UUID | None] = mapped_column(nullable=True)
creator: Mapped["User"] = relationship(
"User",
back_populates="cc_pairs",
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
class Document(Base):
__tablename__ = "document"
@@ -1349,6 +1363,9 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
category_id: Mapped[int | None] = mapped_column(
ForeignKey("persona_category.id"), nullable=True
)
# Allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -1420,6 +1437,9 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
category: Mapped["PersonaCategory"] = relationship(
"PersonaCategory", back_populates="personas"
)
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
@@ -1432,6 +1452,17 @@ class Persona(Base):
)
class PersonaCategory(Base):
__tablename__ = "persona_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str | None] = mapped_column(String, nullable=True)
personas: Mapped[list["Persona"]] = relationship(
"Persona", back_populates="category"
)
AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)
@@ -1441,7 +1472,7 @@ class ChannelConfig(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
channel_names: list[str]
channel_name: str
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
@@ -1456,10 +1487,11 @@ class SlackBotResponseType(str, PyEnum):
CITATIONS = "citations"
class SlackBotConfig(Base):
__tablename__ = "slack_bot_config"
class SlackChannelConfig(Base):
__tablename__ = "slack_channel_config"
id: Mapped[int] = mapped_column(primary_key=True)
slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
@@ -1476,10 +1508,30 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
slack_bot: Mapped["SlackBot"] = relationship(
"SlackBot",
back_populates="slack_channel_configs",
)
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
secondary=SlackChannelConfig__StandardAnswerCategory.__table__,
back_populates="slack_channel_configs",
)
class SlackBot(Base):
__tablename__ = "slack_bot"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
back_populates="slack_bot",
)
@@ -1718,9 +1770,9 @@ class StandardAnswerCategory(Base):
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
slack_channel_configs: Mapped[list["SlackChannelConfig"]] = relationship(
"SlackChannelConfig",
secondary=SlackChannelConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)

View File

@@ -26,6 +26,7 @@ from danswer.db.models import DocumentSet
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
from danswer.db.models import PersonaCategory
from danswer.db.models import Prompt
from danswer.db.models import StarterMessage
from danswer.db.models import Tool
@@ -417,6 +418,7 @@ def upsert_persona(
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
category_id: int | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -487,7 +489,7 @@ def upsert_persona(
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -528,6 +530,7 @@ def upsert_persona(
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)
@@ -743,5 +746,40 @@ def delete_persona_by_name(
)
db_session.execute(stmt)
db_session.commit()
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
return db_session.query(PersonaCategory).all()
def create_assistant_category(
db_session: Session, name: str, description: str
) -> PersonaCategory:
category = PersonaCategory(name=name, description=description)
db_session.add(category)
db_session.commit()
return category
def update_persona_category(
category_id: int,
category_description: str,
category_name: str,
db_session: Session,
) -> None:
persona_category = (
db_session.query(PersonaCategory)
.filter(PersonaCategory.id == category_id)
.one_or_none()
)
if persona_category is None:
raise ValueError(f"Persona category with ID {category_id} does not exist")
persona_category.description = category_description
persona_category.name = category_name
db_session.commit()
def delete_persona_category(category_id: int, db_session: Session) -> None:
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
db_session.commit()

View File

@@ -0,0 +1,76 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import SlackBot
def insert_slack_bot(
db_session: Session,
name: str,
enabled: bool,
bot_token: str,
app_token: str,
) -> SlackBot:
slack_bot = SlackBot(
name=name,
enabled=enabled,
bot_token=bot_token,
app_token=app_token,
)
db_session.add(slack_bot)
db_session.commit()
return slack_bot
def update_slack_bot(
db_session: Session,
slack_bot_id: int,
name: str,
enabled: bool,
bot_token: str,
app_token: str,
) -> SlackBot:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if slack_bot is None:
raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}")
# update the app
slack_bot.name = name
slack_bot.enabled = enabled
slack_bot.bot_token = bot_token
slack_bot.app_token = app_token
db_session.commit()
return slack_bot
def fetch_slack_bot(
db_session: Session,
slack_bot_id: int,
) -> SlackBot:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if slack_bot is None:
raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}")
return slack_bot
def remove_slack_bot(
db_session: Session,
slack_bot_id: int,
) -> None:
slack_bot = fetch_slack_bot(
db_session=db_session,
slack_bot_id=slack_bot_id,
)
db_session.delete(slack_bot)
db_session.commit()
def fetch_slack_bots(db_session: Session) -> Sequence[SlackBot]:
return db_session.scalars(select(SlackBot)).all()

View File

@@ -9,8 +9,8 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
@@ -22,8 +22,8 @@ from danswer.utils.variable_functionality import (
)
def _build_persona_name(channel_names: list[str]) -> str:
return f"{SLACK_BOT_PERSONA_PREFIX}{'-'.join(channel_names)}"
def _build_persona_name(channel_name: str) -> str:
return f"{SLACK_BOT_PERSONA_PREFIX}{channel_name}"
def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
@@ -38,9 +38,9 @@ def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
db_session.delete(rel)
def create_slack_bot_persona(
def create_slack_channel_persona(
db_session: Session,
channel_names: list[str],
channel_name: str,
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
@@ -48,11 +48,11 @@ def create_slack_bot_persona(
) -> Persona:
"""NOTE: does not commit changes"""
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)
default_prompt = get_default_prompt(db_session)
persona = upsert_persona(
user=None, # Slack Bot Personas are not attached to users
user=None, # Slack channel Personas are not attached to users
persona_id=existing_persona_id,
name=persona_name,
description="",
@@ -78,14 +78,15 @@ def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
return []
def insert_slack_bot_config(
def insert_slack_channel_config(
db_session: Session,
slack_bot_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
) -> SlackChannelConfig:
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
@@ -110,34 +111,37 @@ def insert_slack_bot_config(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
slack_channel_config = SlackChannelConfig(
slack_bot_id=slack_bot_id,
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
db_session.add(slack_bot_config)
db_session.add(slack_channel_config)
db_session.commit()
return slack_bot_config
return slack_channel_config
def update_slack_bot_config(
slack_bot_config_id: int,
def update_slack_channel_config(
db_session: Session,
slack_channel_config_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
slack_bot_config = db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
) -> SlackChannelConfig:
slack_channel_config = db_session.scalar(
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
)
if slack_bot_config is None:
if slack_channel_config is None:
raise ValueError(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
f"Unable to find Slack channel config with ID {slack_channel_config_id}"
)
versioned_fetch_standard_answer_categories_by_ids = (
@@ -159,25 +163,25 @@ def update_slack_bot_config(
)
# get the existing persona id before updating the object
existing_persona_id = slack_bot_config.persona_id
existing_persona_id = slack_channel_config.persona_id
# update the config
# NOTE: need to do this before cleaning up the old persona or else we
# will encounter `violates foreign key constraint` errors
slack_bot_config.persona_id = persona_id
slack_bot_config.channel_config = channel_config
slack_bot_config.response_type = response_type
slack_bot_config.standard_answer_categories = list(
slack_channel_config.persona_id = persona_id
slack_channel_config.channel_config = channel_config
slack_channel_config.response_type = response_type
slack_channel_config.standard_answer_categories = list(
existing_standard_answer_categories
)
slack_bot_config.enable_auto_filters = enable_auto_filters
slack_channel_config.enable_auto_filters = enable_auto_filters
# if the persona has changed, then clean up the old persona
if persona_id != existing_persona_id and existing_persona_id:
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
# if the existing persona was one created just for use with this Slack Bot,
# if the existing persona was one created just for use with this Slack channel,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
@@ -188,28 +192,30 @@ def update_slack_bot_config(
db_session.commit()
return slack_bot_config
return slack_channel_config
def remove_slack_bot_config(
slack_bot_config_id: int,
user: User | None,
def remove_slack_channel_config(
db_session: Session,
slack_channel_config_id: int,
user: User | None,
) -> None:
slack_bot_config = db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
slack_channel_config = db_session.scalar(
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
)
if slack_bot_config is None:
if slack_channel_config is None:
raise ValueError(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
f"Unable to find Slack channel config with ID {slack_channel_config_id}"
)
existing_persona_id = slack_bot_config.persona_id
existing_persona_id = slack_channel_config.persona_id
if existing_persona_id:
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
# if the existing persona was one created just for use with this Slack Bot,
# if the existing persona was one created just for use with this Slack channel,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
@@ -221,17 +227,28 @@ def remove_slack_bot_config(
persona_id=existing_persona_id, user=user, db_session=db_session
)
db_session.delete(slack_bot_config)
db_session.delete(slack_channel_config)
db_session.commit()
def fetch_slack_bot_config(
db_session: Session, slack_bot_config_id: int
) -> SlackBotConfig | None:
def fetch_slack_channel_configs(
db_session: Session, slack_bot_id: int | None = None
) -> Sequence[SlackChannelConfig]:
if not slack_bot_id:
return db_session.scalars(select(SlackChannelConfig)).all()
return db_session.scalars(
select(SlackChannelConfig).where(
SlackChannelConfig.slack_bot_id == slack_bot_id
)
).all()
def fetch_slack_channel_config(
db_session: Session, slack_channel_config_id: int
) -> SlackChannelConfig | None:
return db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
)
def fetch_slack_bot_configs(db_session: Session) -> Sequence[SlackBotConfig]:
return db_session.scalars(select(SlackBotConfig)).all()

View File

@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -37,7 +44,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
custom_headers=[header.model_dump() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from uuid import UUID
from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
@@ -10,15 +11,94 @@ from danswer.auth.schemas import UserRole
from danswer.db.models import User
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
raise if:
- requested role is a curator
- requested role is a slack user
- requested role is an external permissioned user
- requested role is a limited user
- current role is a slack user
- current role is an external permissioned user
- current role is a limited user
"""
if current_role == UserRole.SLACK_USER:
raise HTTPException(
status_code=400,
detail="To change a Slack User's role, they must first login to Danswer via the web app.",
)
if current_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail="To change an External Permissioned User's role, they must first login to Danswer via the web app.",
)
if current_role == UserRole.LIMITED:
raise HTTPException(
status_code=400,
detail="To change a Limited User's role, they must first login to Danswer via the web app.",
)
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail="Curator role must be set via the User Group Menu",
)
if requested_role == UserRole.LIMITED:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to a Limited User role. "
"This role is automatically assigned to users through certain endpoints in the API."
),
)
if requested_role == UserRole.SLACK_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to a Slack User role. "
"This role is automatically assigned to users who only use Danswer via Slack."
),
)
if requested_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to an External Permissioned User role. "
"This role is automatically assigned to users who have been "
"pulled in to the system via an external permissions system."
),
)
def list_users(
db_session: Session, email_filter_string: str = "", user: User | None = None
db_session: Session, email_filter_string: str = "", include_external: bool = False
) -> Sequence[User]:
"""List all users. No pagination as of now, as the # of users
is assumed to be relatively small (<< 1 million)"""
stmt = select(User)
where_clause = []
if not include_external:
where_clause.append(User.role != UserRole.EXT_PERM_USER)
if email_filter_string:
stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
where_clause.append(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
stmt = stmt.where(*where_clause)
return db_session.scalars(stmt).unique().all()
@@ -45,55 +125,58 @@ def get_user_by_email(email: str, db_session: Session) -> User | None:
def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
return user
return db_session.query(User).filter(User.id == user_id).first() # type: ignore
def _generate_non_web_user(email: str) -> User:
def _generate_non_web_slack_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
return User(
email=email,
hashed_password=hashed_pass,
has_web_login=False,
role=UserRole.BASIC,
role=UserRole.SLACK_USER,
)
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
email = email.lower()
user = get_user_by_email(email, db_session)
if user is not None:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
db_session.commit()
return user
user = _generate_non_web_user(email=email)
user = _generate_non_web_slack_user(email=email)
db_session.add(user)
db_session.commit()
return user
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.flush() # generate id
return user
def _generate_non_web_permissioned_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
return User(
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
)
def batch_add_non_web_user_if_not_exists__no_commit(
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
emails = [email.lower() for email in emails]
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
new_users.append(_generate_non_web_permissioned_user(email=email))
db_session.add_all(new_users)
db_session.flush() # generate ids
db_session.commit()
return found_users + new_users

View File

@@ -15,7 +15,7 @@ schema DANSWER_CHUNK_NAME {
# Must have an additional field for whether to skip title embeddings
# This information cannot be extracted from either the title field nor title embedding
field skip_title type bool {
indexing: attribute
indexing: attribute
}
# May not always match the `semantic_identifier` e.g. for Slack docs the
# `semantic_identifier` will be the channel name, but the `title` will be empty
@@ -36,7 +36,7 @@ schema DANSWER_CHUNK_NAME {
}
# Title embedding (x1)
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
indexing: attribute
indexing: attribute | index
attribute {
distance-metric: angular
}
@@ -44,7 +44,7 @@ schema DANSWER_CHUNK_NAME {
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
indexing: attribute
indexing: attribute | index
attribute {
distance-metric: angular
}

View File

@@ -2,6 +2,7 @@ import concurrent.futures
import json
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
import httpx
from retry import retry
@@ -194,6 +195,14 @@ def _index_vespa_chunk(
logger.exception(
f"Failed to index document: '{document.id}'. Got response: '{res.text}'"
)
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage usually means "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
raise e

View File

@@ -1,7 +1,9 @@
import traceback
from functools import partial
from http import HTTPStatus
from typing import Protocol
import httpx
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
@@ -20,7 +22,8 @@ from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.document import upsert_documents
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
@@ -56,13 +59,13 @@ class IndexingPipelineProtocol(Protocol):
...
def upsert_documents_in_db(
def _upsert_documents_in_db(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
) -> None:
# Metadata here refers to basic document info, not metadata about the actual content
doc_m_batch: list[DocumentMetadata] = []
document_metadata_list: list[DocumentMetadata] = []
for doc in documents:
first_link = next(
(section.link for section in doc.sections if section.link), ""
@@ -77,12 +80,9 @@ def upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
)
doc_m_batch.append(db_doc_metadata)
document_metadata_list.append(db_doc_metadata)
upsert_documents_complete(
db_session=db_session,
document_metadata_batch=doc_m_batch,
)
upsert_documents(db_session, document_metadata_list)
# Insert document content metadata
for doc in documents:
@@ -95,21 +95,25 @@ def upsert_documents_in_db(
document_id=doc.id,
db_session=db_session,
)
else:
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
continue
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
"""Figures out which documents actually need to be updated. If a document is already present
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.
NB: Still need to associate the document in the DB if multiple connectors are
indexing the same doc."""
id_update_time_map = {
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
}
@@ -152,6 +156,14 @@ def index_doc_batch_with_handler(
tenant_id=tenant_id,
)
except Exception as e:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
if INDEXING_EXCEPTION_LIMIT == 0:
raise
@@ -195,9 +207,9 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = []
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
@@ -212,50 +224,65 @@ def index_doc_batch_prepare(
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
elif (
document.title is not None and not document.title.strip() and empty_contents
):
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
else:
documents.append(document)
continue
document_ids = [document.id for document in documents]
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
else documents
)
# No docs to update either because the batch is empty or every doc was already indexed
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
logger.info(
f"Upserted {len(updatable_docs)} changed docs out of "
f"{len(documents)} total docs into the DB"
)
# for all docs, upsert the document to cc pair relationship
upsert_document_by_connector_credential_pair(
db_session,
index_attempt_metadata.connector_id,
index_attempt_metadata.credential_id,
document_ids,
)
# No docs to process because the batch is empty or every doc was already indexed
if not updatable_docs:
return None
# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
id_to_db_doc_map = {doc.id: doc for doc in db_docs}
return DocumentBatchPrepareContext(
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
)
@log_function_time()
@log_function_time(debug_only=True)
def index_doc_batch(
*,
chunker: Chunker,
@@ -269,7 +296,10 @@ def index_doc_batch(
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
memory requirements
Returns a tuple where the first element is the number of new docs and the
second element is the number of chunks."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -312,9 +342,9 @@ def index_doc_batch(
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for immediate and most likely correct sync, but
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync
# always triggers a final metadata sync via the celery queue
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -351,7 +381,8 @@ def index_doc_batch(
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the connector source's idea of when the doc was last modified
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
@@ -366,10 +397,13 @@ def index_doc_batch(
db_session.commit()
return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
)
return result
def build_indexing_pipeline(
*,

View File

@@ -1,12 +1,6 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import TypeAlias
JSON_ro: TypeAlias = (
Mapping[str, "JSON_ro"] | Sequence["JSON_ro"] | str | int | float | bool | None
)
from danswer.utils.special_types import JSON_ro
class KvKeyNotFoundError(Exception):

View File

@@ -11,11 +11,11 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import is_valid_schema_name
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -263,6 +263,7 @@ class Answer:
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
raw_user_text=self.question,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)

View File

@@ -59,6 +59,7 @@ class AnswerPromptBuilder:
message_history: list[PreviousMessage],
llm_config: LLMConfig,
single_message_history: str | None = None,
raw_user_text: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -88,6 +89,12 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = (
HumanMessage(content=raw_user_text)
if raw_user_text is not None
else user_message
)
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None

View File

@@ -231,16 +231,16 @@ class QuotesProcessor:
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if m:
self.found_answer_start = True
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 70:
logger.warning("LLM did not produce json as prompted")
if self.is_json_prompt and len(self.model_output) > 400:
self.found_answer_end = True
logger.warning("LLM did not produce json as prompted")
logger.debug("Model output thus far:", self.model_output)
return
remaining = self.model_output[m.end() :]

View File

@@ -1,6 +1,8 @@
import json
import os
import traceback
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import cast
@@ -21,15 +23,18 @@ from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.server.utils import mask_string
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
logger = setup_logger()
@@ -39,7 +44,7 @@ logger = setup_logger()
litellm.drop_params = True
litellm.telemetry = False
litellm.set_verbose = LOG_ALL_MODEL_INTERACTIONS
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
def _base_msg_to_role(msg: BaseMessage) -> str:
@@ -195,6 +200,23 @@ def _convert_delta_to_message_chunk(
raise ValueError(f"Unknown role: {role}")
def _prompt_to_dict(
prompt: LanguageModelInput,
) -> Sequence[str | list[str] | dict[str, Any] | tuple[str, str]]:
# NOTE: this must go first, since it is also a Sequence
if isinstance(prompt, str):
return [_convert_message_to_dict(HumanMessage(content=prompt))]
if isinstance(prompt, (list, Sequence)):
return [
_convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
for msg in prompt
]
if isinstance(prompt, PromptValue):
return [_convert_message_to_dict(message) for message in prompt.to_messages()]
class DefaultMultiLLM(LLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -213,6 +235,8 @@ class DefaultMultiLLM(LLM):
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
self._model_provider = model_provider
@@ -223,6 +247,7 @@ class DefaultMultiLLM(LLM):
self._api_base = api_base
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
self._long_term_logger = long_term_logger
# This can be used to store the maximum output tokens for this model.
# self._max_output_tokens = (
@@ -246,12 +271,60 @@ class DefaultMultiLLM(LLM):
model_kwargs: dict[str, Any] = {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
model_kwargs.update({"extra_body": extra_body})
self._model_kwargs = model_kwargs
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
def _safe_model_config(self) -> dict:
dump = self.config.model_dump()
dump["api_key"] = mask_string(dump.get("api_key", ""))
return dump
def _record_call(self, prompt: LanguageModelInput) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{"prompt": _prompt_to_dict(prompt), "model": self._safe_model_config()},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
def _record_result(
self, prompt: LanguageModelInput, model_output: BaseMessage
) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{
"prompt": _prompt_to_dict(prompt),
"content": model_output.content,
"tool_calls": (
model_output.tool_calls
if hasattr(model_output, "tool_calls")
else []
),
"model": self._safe_model_config(),
},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
def _record_error(self, prompt: LanguageModelInput, error: Exception) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{
"prompt": _prompt_to_dict(prompt),
"error": str(error),
"traceback": "".join(
traceback.format_exception(
type(error), error, error.__traceback__
)
),
"model": self._safe_model_config(),
},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
# def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
# # NOTE: This method can be used for calculating the maximum tokens for the stream,
# # but it isn't used in practice due to the computational cost of counting tokens
@@ -284,14 +357,10 @@ class DefaultMultiLLM(LLM):
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
_convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
for msg in prompt
]
elif isinstance(prompt, str):
prompt = [_convert_message_to_dict(HumanMessage(content=prompt))]
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
processed_prompt = _prompt_to_dict(prompt)
self._record_call(processed_prompt)
try:
return litellm.completion(
@@ -304,7 +373,7 @@ class DefaultMultiLLM(LLM):
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=prompt,
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
# streaming choice
@@ -324,6 +393,7 @@ class DefaultMultiLLM(LLM):
**self._model_kwargs,
)
except Exception as e:
self._record_error(processed_prompt, e)
# for break pointing
raise e
@@ -357,7 +427,10 @@ class DefaultMultiLLM(LLM):
)
choice = response.choices[0]
if hasattr(choice, "message"):
return _convert_litellm_message_to_langchain_message(choice.message)
output = _convert_litellm_message_to_langchain_message(choice.message)
if output:
self._record_result(prompt, output)
return output
else:
raise ValueError("Unexpected response choice type")
@@ -406,6 +479,9 @@ class DefaultMultiLLM(LLM):
"The AI model failed partway through generation, please try again."
)
if output:
self._record_result(prompt, output)
if LOG_DANSWER_MODEL_INTERACTIONS and output:
content = output.content or ""
if isinstance(output, AIMessage):

View File

@@ -10,6 +10,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.long_term_log import LongTermLogger
def get_main_llm_from_tuple(
@@ -22,6 +23,7 @@ def get_llms_for_persona(
persona: Persona,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
@@ -32,6 +34,7 @@ def get_llms_for_persona(
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
with get_session_context_manager() as db_session:
@@ -57,6 +60,7 @@ def get_llms_for_persona(
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return _create_llm(model), _create_llm(fast_model)
@@ -66,6 +70,7 @@ def get_default_llms(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
@@ -97,6 +102,7 @@ def get_default_llms(
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return _create_llm(model_name), _create_llm(fast_model_name)
@@ -113,6 +119,7 @@ def get_llm(
temperature: float = GEN_AI_TEMPERATURE,
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
return DefaultMultiLLM(
model_provider=provider,
@@ -125,4 +132,5 @@ def get_llm(
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
long_term_logger=long_term_logger,
)

View File

@@ -64,6 +64,9 @@ from danswer.server.features.prompt.api import basic_router as prompt_router
from danswer.server.features.tool.api import admin_router as admin_tool_router
from danswer.server.features.tool.api import router as tool_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.long_term_logs.long_term_logs_api import (
router as long_term_logs_router,
)
from danswer.server.manage.administrative import router as admin_router
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
from danswer.server.manage.embedding.api import basic_router as embedding_router
@@ -74,6 +77,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -270,6 +276,10 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -309,7 +319,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,

View File

@@ -89,67 +89,70 @@ def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE
id_tuple = (model_provider, model_name)
if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)
tokenizer = None
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]
if model_name:
tokenizer = _try_initialize_tokenizer(model_name, model_provider)
try:
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
)
logger.warning(
if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
logger.error(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {model_name} and fallback model"
) from fallback_error
_TOKENIZER_CACHE[id_tuple] = tokenizer
return _TOKENIZER_CACHE[id_tuple]
def _try_initialize_tokenizer(
model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
tokenizer: BaseTokenizer | None = None
if model_provider is not None:
# Try using TiktokenTokenizer first if model_provider exists
try:
tokenizer = TiktokenTokenizer(model_name)
logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
return tokenizer
except Exception as tiktoken_error:
logger.debug(
f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
)
else:
# If no provider specified, try HuggingFaceTokenizer
try:
tokenizer = HuggingFaceTokenizer(model_name)
logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
return tokenizer
except Exception as hf_error:
logger.warning(
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
)
# If both initializations fail, return None
return None
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
return _DEFAULT_TOKENIZER
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
def tokenizer_trim_content(

View File

@@ -64,6 +64,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
)
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_generator_function_time
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -124,6 +125,11 @@ def stream_answer_objects(
danswerbot_flow=danswerbot_flow,
)
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
@@ -134,7 +140,9 @@ def stream_answer_objects(
persona = temporary_persona if temporary_persona else chat_session.persona
try:
llm, fast_llm = get_llms_for_persona(persona=persona)
llm, fast_llm = get_llms_for_persona(
persona=persona, long_term_logger=long_term_logger
)
except ValueError as e:
logger.error(
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
@@ -237,7 +245,9 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
llm=get_main_llm_from_tuple(
get_llms_for_persona(persona=persona, long_term_logger=long_term_logger)
),
single_message_history=history_str,
tools=[search_tool] if search_tool else [],
force_use_tool=(

View File

@@ -1,6 +1,8 @@
import redis
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -19,6 +21,10 @@ class RedisConnector:
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
self.external_group_sync = RedisConnectorExternalGroupSync(
tenant_id, id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(

View File

@@ -1,9 +1,10 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -13,6 +14,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.models import Document
from danswer.redis.redis_object_helper import RedisObjectHelper
@@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
# documents that should be skipped
self.skip_docs: set[str] = set()
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -45,14 +50,19 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def set_skip_docs(self, skip_docs: set[str]) -> None:
# documents that should be skipped. Note that this classes updates
# the list on the fly
self.skip_docs = skip_docs
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -63,7 +73,11 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
num_docs = 0
for doc in db_session.scalars(stmt).yield_per(1):
doc = cast(Document, doc)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -71,6 +85,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs:
continue
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -93,5 +113,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
)
async_results.append(result)
self.skip_docs.add(doc.id)
return len(async_results)
return len(async_results), num_docs

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -13,6 +14,7 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.models import Document as DbDocument
class RedisConnectorDeletionFenceData(BaseModel):
@@ -83,7 +85,7 @@ class RedisConnectorDelete:
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
lock: RedisLock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""
@@ -97,7 +99,8 @@ class RedisConnectorDelete:
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
for doc_temp in db_session.scalars(stmt).yield_per(1):
doc: DbDocument = doc_temp
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -129,6 +132,10 @@ class RedisConnectorDelete:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"

View File

@@ -0,0 +1,188 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncData(BaseModel):
started: datetime | None
class RedisConnectorPermissionSync:
"""Manages interactions with redis for doc permission sync tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectordocpermissionsync"
FENCE_PREFIX = f"{PREFIX}_fence"
# phase 1 - geneartor task and progress signals
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpermissions+generator
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorpermissions_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorpermissions_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_remaining(self) -> int:
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
def get_active_task_count(self) -> int:
"""Count of active permission sync tasks"""
count = 0
for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
count += 1
return count
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
@property
def payload(self) -> RedisConnectorPermissionSyncData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPermissionSyncData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorPermissionSyncData | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
permission sync tasks to be processed ... just after the generator completes."""
fence_bytes = self.redis.get(self.generator_complete_key)
if fence_bytes is None:
return None
if fence_bytes == b"None":
return None
fence_int = int(cast(bytes, fence_bytes).decode())
return fence_int
@generator_complete.setter
def generator_complete(self, payload: int | None) -> None:
"""Set the payload to an int to set the fence, otherwise if None it will
be deleted"""
if payload is None:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generate_tasks(
self,
celery_app: Celery,
lock: RedisLock | None,
new_permissions: list[DocExternalAccess],
source_string: str,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
# Create a task for each document permission sync
for doc_perm in new_permissions:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# Add task for document permissions sync
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
self.redis.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"update_external_document_permissions_task",
kwargs=dict(
tenant_id=self.tenant_id,
serialized_doc_external_access=doc_perm.to_dict(),
source_string=source_string,
),
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(
RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX + "*"
):
r.delete(key)
for key in r.scan_iter(
RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX + "*"
):
r.delete(key)
for key in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -0,0 +1,134 @@
from typing import cast
import redis
from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorexternalgroupsync"
FENCE_PREFIX = f"{PREFIX}_fence"
# phase 1 - geneartor task and progress signals
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorexternalgroupsync_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorexternalgroupsync_generator_complete
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_remaining(self) -> int:
# todo: move into fence
remaining = cast(int, self.redis.scard(self.taskset_key))
return remaining
def get_active_task_count(self) -> int:
"""Count of active external group syncing tasks"""
count = 0
for _ in self.redis.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
):
count += 1
return count
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
external group syncing tasks to be processed ... just after the generator completes.
"""
fence_bytes = self.redis.get(self.generator_complete_key)
if fence_bytes is None:
return None
if fence_bytes == b"None":
return None
fence_int = int(cast(bytes, fence_bytes).decode())
return fence_int
@generator_complete.setter
def generator_complete(self, payload: int | None) -> None:
"""Set the payload to an int to set the fence, otherwise if None it will
be deleted"""
if payload is None:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
lock: RedisLock | None,
) -> int | None:
pass
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
r.srem(taskset_key, task_id)
return
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(
RedisConnectorExternalGroupSync.GENERATOR_COMPLETE_PREFIX + "*"
):
r.delete(key)
for key in r.scan_iter(
RedisConnectorExternalGroupSync.GENERATOR_PROGRESS_PREFIX + "*"
):
r.delete(key)
for key in r.scan_iter(RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -6,7 +6,7 @@ import redis
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
@@ -71,22 +71,20 @@ class RedisConnectorIndex:
return False
@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
payload: RedisConnectorIndexPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -105,7 +106,7 @@ class RedisConnectorPrune:
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
lock: RedisLock | None,
) -> int | None:
last_lock_time = time.monotonic()
@@ -149,6 +150,12 @@ class RedisConnectorPrune:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,9 +51,9 @@ class RedisDocumentSet(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -84,7 +85,7 @@ class RedisDocumentSet(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

View File

@@ -1,9 +1,9 @@
from abc import ABC
from abc import abstractmethod
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.redis.redis_pool import get_redis_client
@@ -85,7 +85,13 @@ class RedisObjectHelper(ABC):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
pass
) -> tuple[int, int] | None:
"""First element should be the number of actual tasks generated, second should
be the number of docs that were candidates to be synced for the cc pair.
The need for this is when we are syncing stale docs referenced by multiple
connectors. In a single pass across multiple cc pairs, we only want a task
for be created for a particular document id the first time we see it.
The rest can be skipped."""

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -51,15 +52,15 @@ class RedisUserGroup(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
return 0, 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
@@ -67,7 +68,7 @@ class RedisUserGroup(RedisObjectHelper):
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
return 0, 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
@@ -97,7 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
[
{
"url": "https://docs.danswer.dev/more/use_cases/overview",
"title": "Use Cases Overview",
"content": "How to leverage Danswer in your organization\n\nDanswer Overview\nDanswer is the AI Assistant connected to your organization's docs, apps, and people. Danswer makes Generative AI more versatile for work by enabling new types of questions like \"What is the most common feature request we've heard from customers this month\". Whereas other AI systems have no context of your team and are generally unhelpful with work related questions, Danswer makes it possible to ask these questions in natural language and get back answers in seconds.\n\nDanswer can connect to +30 different tools and the use cases are not limited to the ones in the following pages. The highlighted use cases are for inspiration and come from feedback gathered from our users and customers.\n\n\nCommon Getting Started Questions:\n\nWhy are these docs connected in my Danswer deployment?\nAnswer: This is just an example of how connectors work in Danswer. You can connect up your own team's knowledge and you will be able to ask questions unique to your organization. Danswer will keep all of the knowledge up to date and in sync with your connected applications.\n\nIs my data being sent anywhere when I connect it up to Danswer?\nAnswer: No! Danswer is built with data security as our highest priority. We open sourced it so our users can know exactly what is going on with their data. By default all of the document processing happens within Danswer. The only time it is sent outward is for the GenAI call to generate answers.\n\nWhere is the feature for auto sync-ing document level access permissions from all connected sources?\nAnswer: This falls under the Enterprise Edition set of Danswer features built on top of the MIT/community edition. If you are on Danswer Cloud, you have access to them by default. If you're running it yourself, reach out to the Danswer team to receive access.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "Value of Enterprise Search with Danswer\n\nWhat is Enterprise Search and why is it Important?\nAn Enterprise Search system gives team members a single place to access all of the disparate knowledge of an organization. Critical information is saved across a host of channels like call transcripts with prospects, engineering design docs, IT runbooks, customer support email exchanges, project management tickets, and more. As fast moving teams scale up, information gets spread out and more disorganized.\n\nSince it quickly becomes infeasible to check across every source, decisions get made on incomplete information, employee satisfaction decreases, and the most valuable members of your team are tied up with constant distractions as junior teammates are unable to unblock themselves. Danswer solves this problem by letting anyone on the team access all of the knowledge across your organization in a permissioned and secure way. Users can ask questions in natural language and get back answers and documents across all of the connected sources instantly.\n\nWhat's the real cost?\nA typical knowledge worker spends over 2 hours a week on search, but more than that, the cost of incomplete or incorrect information can be extremely high. Customer support/success that isn't able to find the reference to similar cases could cause hours or even days of delay leading to lower customer satisfaction or in the worst case - churn. An account exec not realizing that a prospect had previously mentioned a specific need could lead to lost deals. An engineer not realizing a similar feature had previously been built could result in weeks of wasted development time and tech debt with duplicate implementation. With a lack of knowledge, your whole organization is navigating in the dark - inefficient and mistake prone.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "More than Search\nWhen analyzing the entire corpus of knowledge within your company is as easy as asking a question in a search bar, your entire team can stay informed and up to date. Danswer also makes it trivial to identify where knowledge is well documented and where it is lacking. Team members who are centers of knowledge can begin to effectively document their expertise since it is no longer being thrown into a black hole. All of this allows the organization to achieve higher efficiency and drive business outcomes.\n\nWith Generative AI, the entire user experience has evolved as well. For example, instead of just finding similar cases for your customer support team to reference, Danswer breaks down the issue and explains it so that even the most junior members can understand it. This in turn lets them give the most holistic and technically accurate response possible to your customers. On the other end, even the super stars of your sales team will not be able to review 10 hours of transcripts before hopping on that critical call, but Danswer can easily parse through it in mere seconds and give crucial context to help your team close.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/ai_platform",
"title": "AI Platform",
"content": "Build AI Agents powered by the knowledge and workflows specific to your organization.\n\nBeyond Answers\nAgents enabled by generative AI and reasoning capable models are helping teams to automate their work. Danswer is helping teams make it happen. Danswer provides out of the box user chat sessions, attaching custom tools, handling LLM reasoning, code execution, data analysis, referencing internal knowledge, and much more.\n\nDanswer as a platform is not a no-code agent builder. We are made by developers for developers and this gives your team the full flexibility and power to create agents not constrained by blocks and simple logic paths.\n\nFlexibility and Extensibility\nDanswer is open source and completely whitebox. This not only gives transparency to what happens within the system but also means that your team can directly modify the source code to suit your unique needs.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nDanswer takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/sales",
"title": "Sales",
"content": "Keep your team up to date on every conversation and update so they can close.\n\nRecall Every Detail\nBeing able to instantly revisit every detail of any call without reading transcripts is helping Sales teams provide more tailored pitches, build stronger relationships, and close more deals. Instead of searching and reading through hours of transcripts in preparation for a call, your team can now ask Danswer \"What specific features was ACME interested in seeing for the demo\". Since your team doesn't have time to read every transcript prior to a call, Danswer provides a more thorough summary because it can instantly parse hundreds of pages and distill out the relevant information. Even for fast lookups it becomes much more convenient - for example to brush up on connection building topics by asking \"What rapport building topic did we chat about in the last call with ACME\".\n\nKnow Every Product Update\nIt is impossible for Sales teams to keep up with every product update. Because of this, when a prospect has a question that the Sales team does not know, they have no choice but to rely on the Product and Engineering orgs to get an authoritative answer. Not only is this distracting to the other teams, it also slows down the time to respond to the prospect (and as we know, time is the biggest killer of deals). With Danswer, it is even possible to get answers live on call because of how fast accessing information becomes. A question like \"Have we shipped the Microsoft AD integration yet?\" can now be answered in seconds meaning that prospects can get answers while on the call instead of asynchronously and sales cycles are reduced as a result.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/operations",
"title": "Operations",
"content": "Double the productivity of your Ops teams like IT, HR, etc.\n\nAutomatically Resolve Tickets\nModern teams are leveraging AI to auto-resolve up to 50% of tickets. Whether it is an employee asking about benefits details or how to set up the VPN for remote work, Danswer can help your team help themselves. This frees up your team to do the real impactful work of landing star candidates or improving your internal processes.\n\nAI Aided Onboarding\nOne of the periods where your team needs the most help is when they're just ramping up. Instead of feeling lost in dozens of new tools, Danswer gives them a single place where they can ask about anything in natural language. Whether it's how to set up their work environment or what their onboarding goals are, Danswer can walk them through every step with the help of Generative AI. This lets your team feel more empowered and gives time back to the more seasoned members of your team to focus on moving the needle.",
"chunk_ind": 0
}
]

Some files were not shown because too many files have changed in this diff Show More