mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
195 Commits
fix-deleti
...
debug-test
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f82de7c45 | ||
|
|
45f67368a2 | ||
|
|
014ba9e220 | ||
|
|
ba64543dd7 | ||
|
|
18c62a0c24 | ||
|
|
33f555922c | ||
|
|
05f6f6d5b5 | ||
|
|
19dae1d870 | ||
|
|
6d859bd37c | ||
|
|
122e3fa3fa | ||
|
|
87b542b335 | ||
|
|
00229d2abe | ||
|
|
5f2644985c | ||
|
|
c82a36ad68 | ||
|
|
16d1c19d9f | ||
|
|
9f179940f8 | ||
|
|
8a8e2b310e | ||
|
|
2274cab554 | ||
|
|
ef104e9a82 | ||
|
|
a575d7f1eb | ||
|
|
f404c4b448 | ||
|
|
3884f1d70a | ||
|
|
bc9d5fece7 | ||
|
|
bb279a8580 | ||
|
|
a9403016c9 | ||
|
|
f3cea79c1c | ||
|
|
54bb79303c | ||
|
|
d3dfabb20e | ||
|
|
7d1ec1095c | ||
|
|
f531d071af | ||
|
|
4218814385 | ||
|
|
e662e3b57d | ||
|
|
2073820e33 | ||
|
|
5f25b243c5 | ||
|
|
a9427f190a | ||
|
|
18fbe9d7e8 | ||
|
|
75c9b1cafe | ||
|
|
632a8f700b | ||
|
|
cd58c96014 | ||
|
|
c5032d25c9 | ||
|
|
72acde6fd4 | ||
|
|
5596a68d08 | ||
|
|
5b18409c89 | ||
|
|
84272af5ac | ||
|
|
6bef70c8b7 | ||
|
|
7f7559e3d2 | ||
|
|
7ba829a585 | ||
|
|
8b2ecb4eab | ||
|
|
2dd3870504 | ||
|
|
df464fc54b | ||
|
|
96b98fbc4a | ||
|
|
66cf67d04d | ||
|
|
285bdbbaf9 | ||
|
|
e2c37d6847 | ||
|
|
3ff2ba7ee4 | ||
|
|
290f4f0f8c | ||
|
|
3c934a93cd | ||
|
|
a51b0f636e | ||
|
|
a50c2e30ec | ||
|
|
ee278522ef | ||
|
|
430c9a47d7 | ||
|
|
974f85da66 | ||
|
|
a63cb9da43 | ||
|
|
d807ad7699 | ||
|
|
3cb00de6d4 | ||
|
|
da6e46ae75 | ||
|
|
648c2531f9 | ||
|
|
fc98c560a4 | ||
|
|
566f44fcd6 | ||
|
|
2fe49e5efb | ||
|
|
f58acd4e2a | ||
|
|
53008a0271 | ||
|
|
13278663d9 | ||
|
|
31ca6857fb | ||
|
|
6dd91414be | ||
|
|
140c34e59e | ||
|
|
da8e68b320 | ||
|
|
e9a616e579 | ||
|
|
cb2169f2a3 | ||
|
|
79aa5dd6e0 | ||
|
|
604ebafe6c | ||
|
|
a2d775efbd | ||
|
|
641690e3f7 | ||
|
|
eebf98e3a6 | ||
|
|
4bc4da29f5 | ||
|
|
7af572d0e7 | ||
|
|
58bdf9d684 | ||
|
|
f69922fff7 | ||
|
|
d4d37c9cdd | ||
|
|
2654df49fd | ||
|
|
aee5fcd4e0 | ||
|
|
2c77dd241b | ||
|
|
d90c90dd92 | ||
|
|
2c971cf774 | ||
|
|
eab55bdd85 | ||
|
|
f4f2fb5943 | ||
|
|
71f2f1a90a | ||
|
|
74a2271422 | ||
|
|
d42fb6ce34 | ||
|
|
0d749ebd46 | ||
|
|
9f6e8bd124 | ||
|
|
3a2a6abed4 | ||
|
|
07f49a384f | ||
|
|
f1c5e80f17 | ||
|
|
b7ad810d83 | ||
|
|
99b28643f7 | ||
|
|
f52d1142eb | ||
|
|
e563746730 | ||
|
|
aa86830bde | ||
|
|
4558351801 | ||
|
|
a4dcae57cd | ||
|
|
dbd56f946f | ||
|
|
e4e4765c60 | ||
|
|
c967f53c02 | ||
|
|
3a9b964d5c | ||
|
|
f04ecbf87a | ||
|
|
362156f97e | ||
|
|
3fa9676478 | ||
|
|
be4b6189d2 | ||
|
|
ace041415a | ||
|
|
148c2a7375 | ||
|
|
1555ac9dab | ||
|
|
80de408cef | ||
|
|
e20c825e16 | ||
|
|
b0568ac8ae | ||
|
|
0896d3b7da | ||
|
|
87b27046bd | ||
|
|
5e9c6d1499 | ||
|
|
50211ec401 | ||
|
|
6012a7cbd9 | ||
|
|
1e4b27185d | ||
|
|
0c66da17bb | ||
|
|
d985cd4352 | ||
|
|
c8891a5829 | ||
|
|
51a13f5fc7 | ||
|
|
57c1deb8b8 | ||
|
|
e2e04af7e2 | ||
|
|
c1735fcd3a | ||
|
|
b43e5735d7 | ||
|
|
7d4f8ef4e8 | ||
|
|
7c03b6f521 | ||
|
|
ccf986808c | ||
|
|
350482e53e | ||
|
|
fb3d7330fa | ||
|
|
6cec31088d | ||
|
|
491f3254a5 | ||
|
|
5abf67fbf0 | ||
|
|
2933c3598b | ||
|
|
aeb6060854 | ||
|
|
8977b1b5fc | ||
|
|
69c0419146 | ||
|
|
2bd3833c55 | ||
|
|
2d7b312e6c | ||
|
|
ebe3674ca7 | ||
|
|
04f83eb1e1 | ||
|
|
420aabc963 | ||
|
|
61a17319c9 | ||
|
|
e4c85352b4 | ||
|
|
34ba3181ff | ||
|
|
630e2248bd | ||
|
|
c358c91e4c | ||
|
|
2b7915f33b | ||
|
|
0ff1a023cd | ||
|
|
d68d281e1c | ||
|
|
ebce3ff6ba | ||
|
|
f96bd12ab8 | ||
|
|
32359d2dff | ||
|
|
5da6d792de | ||
|
|
fb95398e5b | ||
|
|
af66650ee3 | ||
|
|
5b1f3c8d4e | ||
|
|
a3b1b1db38 | ||
|
|
7520fae068 | ||
|
|
39c946536c | ||
|
|
90528ba195 | ||
|
|
6afcaafe54 | ||
|
|
812ca69949 | ||
|
|
abe01144ca | ||
|
|
d988a3e736 | ||
|
|
2b14afe878 | ||
|
|
033ec0b6b1 | ||
|
|
14a9fecc64 | ||
|
|
0027f161d7 | ||
|
|
32e551b69c | ||
|
|
299cb5035c | ||
|
|
910821c723 | ||
|
|
aa84846298 | ||
|
|
c122be2f6a | ||
|
|
f871b4c6eb | ||
|
|
a96cea2ce0 | ||
|
|
8d443ada5b | ||
|
|
634de83d72 | ||
|
|
580848cf8c | ||
|
|
f01027cfb7 | ||
|
|
76db4b765a |
76
.github/actions/custom-build-and-push/action.yml
vendored
Normal file
76
.github/actions/custom-build-and-push/action.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
name: 'Build and Push Docker Image with Retry'
|
||||
description: 'Attempts to build and push a Docker image, with a retry on failure'
|
||||
inputs:
|
||||
context:
|
||||
description: 'Build context'
|
||||
required: true
|
||||
file:
|
||||
description: 'Dockerfile location'
|
||||
required: true
|
||||
platforms:
|
||||
description: 'Target platforms'
|
||||
required: true
|
||||
pull:
|
||||
description: 'Always attempt to pull a newer version of the image'
|
||||
required: false
|
||||
default: 'true'
|
||||
push:
|
||||
description: 'Push the image to registry'
|
||||
required: false
|
||||
default: 'true'
|
||||
load:
|
||||
description: 'Load the image into Docker daemon'
|
||||
required: false
|
||||
default: 'true'
|
||||
tags:
|
||||
description: 'Image tags'
|
||||
required: true
|
||||
cache-from:
|
||||
description: 'Cache sources'
|
||||
required: false
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before retry in seconds'
|
||||
required: false
|
||||
default: '5'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait to retry
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
@@ -27,6 +27,11 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
|
||||
67
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
67
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
6
.github/workflows/pr-python-checks.yml
vendored
6
.github/workflows/pr-python-checks.yml
vendored
@@ -24,9 +24,9 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Run MyPy
|
||||
run: |
|
||||
|
||||
@@ -10,6 +10,9 @@ on:
|
||||
env:
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
@@ -36,8 +39,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
7
.github/workflows/pr-python-tests.yml
vendored
7
.github/workflows/pr-python-tests.yml
vendored
@@ -11,7 +11,8 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -28,8 +29,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
53
.github/workflows/run-it.yml
vendored
53
.github/workflows/run-it.yml
vendored
@@ -13,8 +13,7 @@ env:
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
runs-on: Amd64
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -28,30 +27,20 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/arm64
|
||||
pull: true
|
||||
push: true
|
||||
load: true
|
||||
tags: danswer/danswer-web-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-web-server:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-web-server:it,mode=max
|
||||
type=inline
|
||||
# NOTE: we don't need to build the Web Docker image since it's not used
|
||||
# during the IT for now. We have a separate action to verify it builds
|
||||
# succesfully
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
pull: true
|
||||
push: true
|
||||
load: true
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
@@ -59,14 +48,11 @@ jobs:
|
||||
type=inline
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
pull: true
|
||||
push: true
|
||||
load: true
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
@@ -74,14 +60,11 @@ jobs:
|
||||
type=inline
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
pull: true
|
||||
push: true
|
||||
load: true
|
||||
platforms: linux/amd64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
@@ -92,14 +75,19 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=it \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
@@ -137,6 +125,7 @@ jobs:
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
danswer/integration-test-runner:it
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,6 +4,6 @@
|
||||
.mypy_cache
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/launch.json
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
15
.vscode/env_template.txt
vendored
15
.vscode/env_template.txt
vendored
@@ -1,5 +1,5 @@
|
||||
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
|
||||
# This will help with development iteration speed and reduce repeat tasks for dev
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_DOC_RELEVANCE=True
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
|
||||
|
||||
# Python stuff
|
||||
PYTHONPATH=./backend
|
||||
PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
@@ -49,4 +49,3 @@ BING_API_KEY=<REPLACE THIS>
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
84
.vscode/launch.template.jsonc
vendored
84
.vscode/launch.template.jsonc
vendored
@@ -1,15 +1,23 @@
|
||||
/*
|
||||
|
||||
Copy this file into '.vscode/launch.json' or merge its
|
||||
contents into your existing configurations.
|
||||
|
||||
*/
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Web Server",
|
||||
@@ -17,7 +25,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -25,11 +33,12 @@
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"type": "python",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -39,16 +48,16 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"consoleTitle": "Model Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"type": "python",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -59,32 +68,32 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"consoleTitle": "API Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Indexing",
|
||||
"type": "python",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"consoleTitle": "Indexing"
|
||||
}
|
||||
},
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"type": "python",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -93,18 +102,18 @@
|
||||
},
|
||||
"args": [
|
||||
"--no-indexing"
|
||||
],
|
||||
"consoleTitle": "Background Jobs"
|
||||
]
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"type": "python",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -113,11 +122,12 @@
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"type": "python",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -128,18 +138,16 @@
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
},
|
||||
{
|
||||
"name": "Run Danswer",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
]
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -48,23 +48,26 @@ We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external pieces of software, specifically:
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
development purposes but also feel free to just use the containers and update with local changes by providing the
|
||||
`--build` flag.
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
|
||||
|
||||
|
||||
### Local Set Up
|
||||
It is recommended to use Python version 3.11
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, the version of Tensorflow we use may not be available for your platform.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
|
||||
#### Installing Requirements
|
||||
#### Backend: Python requirements
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
@@ -73,8 +76,9 @@ python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
|
||||
directory
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
```bash
|
||||
@@ -89,34 +93,38 @@ Install the required python dependencies:
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
Install Playwright (required by the Web Connector)
|
||||
#### Docker containers for external software
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
```bash
|
||||
playwright install
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
```
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
|
||||
#### Dependent Docker Containers
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
|
||||
```
|
||||
(index refers to Vespa and relational_db refers to Postgres)
|
||||
|
||||
#### Running Danswer
|
||||
#### Running Danswer locally
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
npm run dev
|
||||
@@ -127,11 +135,10 @@ Navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
"
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
@@ -154,6 +161,7 @@ To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
@@ -162,20 +170,58 @@ powershell -Command "
|
||||
"
|
||||
```
|
||||
|
||||
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
|
||||
|
||||
You've successfully set up a local Danswer instance! 🏁
|
||||
|
||||
#### Running the Danswer application in a container
|
||||
|
||||
You can run the full Danswer application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `danswer/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
|
||||
|
||||
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
#### Backend
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we would like to keep it that way!
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
|
||||
@@ -186,6 +232,7 @@ Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer follows the semver versioning standard.
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
|
||||
31
CONTRIBUTING_MACOS.md
Normal file
31
CONTRIBUTING_MACOS.md
Normal file
@@ -0,0 +1,31 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
@@ -9,7 +9,8 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
@@ -40,6 +41,8 @@ RUN apt-get update && \
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
@@ -75,8 +78,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('wordnet', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
@@ -8,11 +8,17 @@ visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
@@ -16,7 +16,9 @@ config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add ccpair deletion failure message
|
||||
|
||||
Revision ID: 0ebb1d516877
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-10 15:03:48.233926
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0ebb1d516877"
|
||||
down_revision = "52a219fb5233"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("deletion_failure_message", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "deletion_failure_message")
|
||||
@@ -0,0 +1,102 @@
|
||||
"""add_user_delete_cascades
|
||||
|
||||
Revision ID: 1b8206b29c5d
|
||||
Revises: 35e6853a51d5
|
||||
Create Date: 2024-09-18 11:48:59.418726
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b8206b29c5d"
|
||||
down_revision = "35e6853a51d5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey",
|
||||
"credential",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey",
|
||||
"chat_session",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey",
|
||||
"chat_folder",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey",
|
||||
"notification",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey",
|
||||
"inputprompt",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
|
||||
)
|
||||
@@ -30,7 +30,7 @@ def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
|
||||
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""server default chosen assistants
|
||||
|
||||
Revision ID: 35e6853a51d5
|
||||
Revises: c99d76fcd298
|
||||
Create Date: 2024-09-13 13:20:32.885317
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e6853a51d5"
|
||||
down_revision = "c99d76fcd298"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_ASSISTANTS = [-2, -1, 0]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Update any NULL values to the default value
|
||||
# This upgrades existing users without ordered assistant
|
||||
# to have default assistants set to visible assistants which are
|
||||
# accessible by them.
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user" u
|
||||
SET chosen_assistants = (
|
||||
SELECT jsonb_agg(
|
||||
p.id ORDER BY
|
||||
COALESCE(p.display_priority, 2147483647) ASC,
|
||||
p.id ASC
|
||||
)
|
||||
FROM persona p
|
||||
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
|
||||
WHERE p.is_visible = true
|
||||
AND (p.is_public = true OR pu.user_id IS NOT NULL)
|
||||
)
|
||||
WHERE chosen_assistants IS NULL
|
||||
OR chosen_assistants = 'null'
|
||||
OR jsonb_typeof(chosen_assistants) = 'null'
|
||||
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 2: Alter the column to make it non-nullable
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f7e58d357687
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "52a219fb5233"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last modified represents the last time anything needing syncing to vespa changed
|
||||
# including row metadata and the document itself. This obviously does not include
|
||||
# the last_synced column.
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# last synced represents the last time this document was synced to Vespa
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Set last_synced to the same value as last_modified for existing rows
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET last_synced = last_modified
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_modified"),
|
||||
"document",
|
||||
["last_modified"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_synced"),
|
||||
"document",
|
||||
["last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
|
||||
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
|
||||
op.drop_column("document", "last_synced")
|
||||
op.drop_column("document", "last_modified")
|
||||
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""assistant_rework
|
||||
|
||||
Revision ID: 55546a7967ee
|
||||
Revises: 61ff3651add4
|
||||
Create Date: 2024-09-18 17:00:23.755399
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55546a7967ee"
|
||||
down_revision = "61ff3651add4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Reworking persona and user tables for new assistant features
|
||||
# keep track of user's chosen assistants separate from their `ordering`
|
||||
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET builtin_persona = default_persona")
|
||||
op.alter_column("persona", "builtin_persona", nullable=False)
|
||||
op.drop_index("_default_persona_name_idx", table_name="persona")
|
||||
op.create_index(
|
||||
"_builtin_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("builtin_persona = true"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"visible_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"hidden_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.drop_column("persona", "default_persona")
|
||||
op.add_column(
|
||||
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverting changes made in upgrade
|
||||
op.drop_column("user", "hidden_assistants")
|
||||
op.drop_column("user", "visible_assistants")
|
||||
op.drop_index("_builtin_persona_name_idx", table_name="persona")
|
||||
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET default_persona = builtin_persona")
|
||||
op.alter_column("persona", "default_persona", nullable=False)
|
||||
op.drop_column("persona", "builtin_persona")
|
||||
op.create_index(
|
||||
"_default_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
@@ -0,0 +1,35 @@
|
||||
"""match_any_keywords flag for standard answers
|
||||
|
||||
Revision ID: 5c7fdadae813
|
||||
Revises: efb35676026c
|
||||
Create Date: 2024-09-13 18:52:59.256478
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c7fdadae813"
|
||||
down_revision = "efb35676026c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column(
|
||||
"match_any_keywords",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("standard_answer", "match_any_keywords")
|
||||
# ### end Alembic commands ###
|
||||
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Add Permission Syncing
|
||||
|
||||
Revision ID: 61ff3651add4
|
||||
Revises: 1b8206b29c5d
|
||||
Create Date: 2024-09-05 13:57:11.770413
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "61ff3651add4"
|
||||
down_revision = "1b8206b29c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Admin user who set up connectors will lose access to the docs temporarily
|
||||
# only way currently to give back access is to rerun from beginning
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"access_type",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "access_type", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"auto_sync_options",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.drop_column("connector_credential_pair", "is_public")
|
||||
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"user__external_user_group_id",
|
||||
sa.Column(
|
||||
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
|
||||
),
|
||||
sa.Column("external_user_group_id", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
|
||||
op.drop_column("external_permission", "user_id")
|
||||
op.drop_column("email_to_external_user_cache", "user_id")
|
||||
op.drop_table("permission_sync_run")
|
||||
op.drop_table("external_permission")
|
||||
op.drop_table("email_to_external_user_cache")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "is_public", nullable=False)
|
||||
|
||||
op.drop_column("connector_credential_pair", "auto_sync_options")
|
||||
op.drop_column("connector_credential_pair", "access_type")
|
||||
op.drop_column("connector_credential_pair", "last_time_perm_sync")
|
||||
op.drop_column("document", "external_user_emails")
|
||||
op.drop_column("document", "external_user_group_ids")
|
||||
op.drop_column("document", "is_public")
|
||||
|
||||
op.drop_table("user__external_user_group_id")
|
||||
|
||||
# Drop the enum type at the end of the downgrade
|
||||
op.create_table(
|
||||
"permission_sync_run",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("update_type", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"external_permission",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("external_permission_group", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"email_to_external_user_cache",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_user_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""persona_start_date
|
||||
|
||||
Revision ID: 797089dfb4d2
|
||||
Revises: 55546a7967ee
|
||||
Create Date: 2024-09-11 14:51:49.785835
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "797089dfb4d2"
|
||||
down_revision = "55546a7967ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "search_start_date")
|
||||
@@ -0,0 +1,158 @@
|
||||
"""migration confluence to be explicit
|
||||
|
||||
Revision ID: a3795dce87be
|
||||
Revises: 1f60f60c3401
|
||||
Create Date: 2024-09-01 13:52:12.006740
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql import table, column
|
||||
|
||||
revision = "a3795dce87be"
|
||||
down_revision = "1f60f60c3401"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url: str,
|
||||
) -> tuple[str, str, str]:
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
def reconstruct_confluence_url(
|
||||
wiki_base: str, space: str, page_id: str, is_cloud: bool
|
||||
) -> str:
|
||||
if is_cloud:
|
||||
url = f"{wiki_base}/spaces/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
else:
|
||||
url = f"{wiki_base}/display/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
return url
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
# Fetch all Confluence connectors
|
||||
connection = op.get_bind()
|
||||
confluence_connectors = connection.execute(
|
||||
sa.select(connector).where(
|
||||
sa.and_(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
wiki_page_url = config["wiki_page_url"]
|
||||
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
|
||||
new_config = {
|
||||
"wiki_base": wiki_base,
|
||||
"space": space,
|
||||
"page_id": page_id,
|
||||
"is_cloud": is_cloud,
|
||||
}
|
||||
|
||||
for key, value in config.items():
|
||||
if key not in ["wiki_page_url"]:
|
||||
new_config[key] = value
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
confluence_connectors = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.select(connector).where(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
|
||||
wiki_page_url = reconstruct_confluence_url(
|
||||
config["wiki_base"],
|
||||
config["space"],
|
||||
config.get("page_id", ""),
|
||||
config["is_cloud"],
|
||||
)
|
||||
|
||||
new_config = {"wiki_page_url": wiki_page_url}
|
||||
new_config.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
|
||||
}
|
||||
)
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Add base_url to CloudEmbeddingProvider
|
||||
|
||||
Revision ID: bceb1e139447
|
||||
Revises: a3795dce87be
|
||||
Create Date: 2024-08-28 17:00:52.554580
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb1e139447"
|
||||
down_revision = "a3795dce87be"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "api_url")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""non nullable default persona
|
||||
|
||||
Revision ID: bd2921608c3a
|
||||
Revises: 797089dfb4d2
|
||||
Create Date: 2024-09-20 10:28:37.992042
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bd2921608c3a"
|
||||
down_revision = "797089dfb4d2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set existing NULL values to False
|
||||
op.execute(
|
||||
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
|
||||
)
|
||||
|
||||
# Alter the column to be not nullable with a default value of False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert the changes
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add nullable to persona id in Chat Session
|
||||
|
||||
Revision ID: c99d76fcd298
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-07-09 19:27:01.579697
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c99d76fcd298"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 0ebb1d516877
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "efb35676026c"
|
||||
down_revision = "0ebb1d516877"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column(
|
||||
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("standard_answer", "match_regex")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add custom headers to tools
|
||||
|
||||
Revision ID: f32615f71aeb
|
||||
Revises: bd2921608c3a
|
||||
Create Date: 2024-09-12 20:26:38.932377
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f32615f71aeb"
|
||||
down_revision = "bd2921608c3a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "custom_headers")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: ba98eba0f66a
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7e58d357687"
|
||||
down_revision = "ba98eba0f66a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "has_web_login")
|
||||
@@ -1,26 +1,81 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
info = get_access_info_for_document(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
return DocumentAccess.build(
|
||||
user_emails=info[1] if info and info[1] else [],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=info[2] if info else False,
|
||||
)
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
return DocumentAccess(
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
is_public=False,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
document_access_info = get_access_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
doc_access = {
|
||||
document_id: DocumentAccess(
|
||||
user_emails=set([email for email in user_emails if email]),
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=set(),
|
||||
is_public=is_public,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not be indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
# checks the MIT version permissions and creates a superset. This ensures that this flow
|
||||
# does not fail even if the Document has not yet been indexed.
|
||||
for doc_id in document_ids:
|
||||
if doc_id not in doc_access:
|
||||
doc_access[doc_id] = get_null_document_access()
|
||||
return doc_access
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
@@ -42,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
matches one entry in the returned set.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
return {PUBLIC_DOC_PAT}
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,72 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_external_group
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
class ExternalAccess:
|
||||
# Emails of external users with access to the doc externally
|
||||
external_user_emails: set[str]
|
||||
# Names or external IDs of groups with access to the doc
|
||||
external_user_group_ids: set[str]
|
||||
# Whether the document is public in the external system or Danswer
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
user_emails: set[str | None]
|
||||
# Names of user groups associated with this document
|
||||
user_groups: set[str]
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
return set(
|
||||
[
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.user_emails
|
||||
if user_email
|
||||
]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ [
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.external_user_emails
|
||||
]
|
||||
+ [
|
||||
# The group names are already prefixed by the source type
|
||||
# This adds an additional prefix of "external_group:"
|
||||
prefix_external_group(group_name)
|
||||
for group_name in self.external_user_group_ids
|
||||
]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
cls,
|
||||
user_emails: list[str | None],
|
||||
user_groups: list[str],
|
||||
external_user_emails: list[str],
|
||||
external_user_group_ids: list[str],
|
||||
is_public: bool,
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
external_user_emails={
|
||||
prefix_user_email(external_email)
|
||||
for external_email in external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
prefix_external_group(external_group_id)
|
||||
for external_group_id in external_user_group_ids
|
||||
},
|
||||
user_emails={
|
||||
prefix_user_email(user_email)
|
||||
for user_email in user_emails
|
||||
if user_email
|
||||
},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
def prefix_user_email(user_email: str) -> str:
|
||||
"""Prefixes a user email to eliminate collision with group names.
|
||||
This applies to both a Danswer user and an External user, this is to make the query time
|
||||
more efficient"""
|
||||
return f"user_email:{user_email}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
"""Prefixes a user group name to eliminate collision with user emails.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def prefix_external_group(ext_group_name: str) -> str:
|
||||
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
|
||||
@@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
has_web_login: bool | None = True
|
||||
|
||||
@@ -16,7 +16,9 @@ from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
@@ -33,6 +35,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
@@ -67,23 +70,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_curator_request(groups: list | None, is_public: bool) -> None:
|
||||
if is_public:
|
||||
detail = "Curators cannot create public objects"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=detail,
|
||||
)
|
||||
if not groups:
|
||||
detail = "Curators must specify 1+ groups"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -201,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> models.UP:
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
@@ -210,7 +196,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -251,6 +257,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
@@ -279,6 +297,32 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
except exceptions.UserNotExists:
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(user, {"hashed_password": updated_password_hash})
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -381,6 +425,7 @@ async def optional_user(
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
@@ -397,7 +442,11 @@ async def double_check_user(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
@@ -406,6 +455,12 @@ async def double_check_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
361
backend/danswer/background/celery/celery_redis.py
Normal file
361
backend/danswer/background/celery/celery_redis.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id, current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class differs from the default in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -3,9 +3,8 @@ from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
@@ -16,30 +15,44 @@ from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
redis_pool = RedisPool()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
"""
|
||||
cc_pair = get_connector_credential_pair(
|
||||
connector_id=connector_id, credential_id=credential_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = redis_pool.get_client()
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
@@ -56,46 +69,6 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now.")
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
|
||||
76
backend/danswer/background/celery/celeryconfig.py
Normal file
76
backend/danswer/background/celery/celeryconfig.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
REDIS_SCHEME = "redis"
|
||||
|
||||
# SSL-specific query parameters for Redis URL
|
||||
SSL_QUERY_PARAMS = ""
|
||||
if REDIS_SSL:
|
||||
REDIS_SCHEME = "rediss"
|
||||
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
}
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
# Option 1: Reduces generator task result sizes by roughly 20%
|
||||
# task_compression = "bzip2"
|
||||
# task_serializer = "pickle"
|
||||
# result_compression = "bzip2"
|
||||
# result_serializer = "pickle"
|
||||
# accept_content=["pickle"]
|
||||
|
||||
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
|
||||
# can be large. small tasks change very little
|
||||
# def pickle_bz2_encoder(data):
|
||||
# return bz2.compress(pickle.dumps(data))
|
||||
|
||||
# def pickle_bz2_decoder(data):
|
||||
# return pickle.loads(bz2.decompress(data))
|
||||
|
||||
# from kombu import serialization # To register custom serialization with Celery/Kombu
|
||||
|
||||
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
|
||||
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
@@ -13,28 +13,16 @@ connector / credential pair from the access list
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import get_document_connector_counts
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -57,13 +45,15 @@ def delete_connector_credential_pair_batch(
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
document_connector_counts = get_document_connector_counts(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
document_id
|
||||
for document_id, cnt in document_connector_counts
|
||||
if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
@@ -76,7 +66,7 @@ def delete_connector_credential_pair_batch(
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
document_id for document_id, cnt in document_connector_counts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
@@ -109,7 +99,7 @@ def delete_connector_credential_pair_batch(
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
@@ -118,79 +108,3 @@ def delete_connector_credential_pair_batch(
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> int:
|
||||
connector_id = cc_pair.connector_id
|
||||
credential_id = cc_pair.credential_id
|
||||
|
||||
num_docs_deleted = 0
|
||||
while True:
|
||||
documents = get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
limit=_DELETION_BATCH_SIZE,
|
||||
)
|
||||
if not documents:
|
||||
break
|
||||
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.notice(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
return num_docs_deleted
|
||||
|
||||
@@ -56,11 +56,11 @@ def _get_connector_runner(
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
task,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
db_session=db_session,
|
||||
source=attempt.connector_credential_pair.connector.source,
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
@@ -384,17 +384,22 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
|
||||
@@ -14,14 +14,6 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
@@ -93,9 +85,16 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# mark the task as started
|
||||
# register_task must come before fn = apply_async or else the task
|
||||
# might run mark_task_start (and crash) before the task row exists
|
||||
db_task = register_task(task_name, db_session)
|
||||
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
# we update the celery task id for diagnostic purposes
|
||||
# but it isn't currently used by any code
|
||||
db_task.task_id = task.id
|
||||
db_session.commit()
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If the indexing dies, it's most likely due to resource constraints,
|
||||
@@ -68,7 +67,7 @@ def _should_create_new_indexing(
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE`
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
@@ -212,7 +211,6 @@ def cleanup_indexing_jobs(
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
@@ -313,7 +311,12 @@ def kickoff_indexing_jobs(
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
primary_client_full = False
|
||||
secondary_client_full = False
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
if primary_client_full and secondary_client_full:
|
||||
break
|
||||
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
@@ -338,20 +341,28 @@ def kickoff_indexing_jobs(
|
||||
)
|
||||
continue
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not use_secondary_index:
|
||||
if not primary_client_full:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
primary_client_full = True
|
||||
else:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not secondary_client_full:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
secondary_client_full = True
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
|
||||
@@ -122,7 +122,7 @@ def load_personas_from_yaml(
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
default_persona=True,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -44,8 +45,26 @@ class QADocsResponse(RetrievalDocs):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class FinalUsedContextDocsResponse(BaseModel):
|
||||
final_context_docs: list[LlmDoc]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
@@ -78,6 +97,16 @@ class CitationInfo(BaseModel):
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
|
||||
# This is a mapping of the citation number to the document index within
|
||||
# the result search doc set
|
||||
class MessageSpecificCitations(BaseModel):
|
||||
citation_map: dict[int, int]
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
@@ -123,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
@@ -144,6 +173,7 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,15 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
@@ -70,7 +73,9 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
@@ -85,6 +90,8 @@ from danswer.tools.internet_search.internet_search_tool import (
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@@ -100,9 +107,9 @@ from danswer.utils.timing import log_generator_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def translate_citations(
|
||||
def _translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
) -> MessageSpecificCitations:
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
@@ -117,7 +124,7 @@ def translate_citations(
|
||||
citation.citation_num
|
||||
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
||||
|
||||
return citation_to_saved_doc_id_map
|
||||
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
|
||||
|
||||
|
||||
def _handle_search_tool_response_summary(
|
||||
@@ -239,11 +246,14 @@ ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
@@ -263,6 +273,7 @@ def stream_chat_message_objects(
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -434,6 +445,7 @@ def stream_chat_message_objects(
|
||||
chat_session=chat_session,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
@@ -597,8 +609,13 @@ def stream_chat_message_objects(
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=db_tool_model.custom_headers,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -663,9 +680,11 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
@@ -688,9 +707,14 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=llm_indices
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
@@ -727,10 +751,18 @@ def stream_chat_message_objects(
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
@@ -743,12 +775,13 @@ def stream_chat_message_objects(
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
db_citations = None
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
db_citations = translate_citations(
|
||||
message_specific_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
@@ -765,18 +798,22 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
||||
@@ -126,6 +126,7 @@ try:
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
@@ -149,6 +150,27 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
|
||||
@@ -83,8 +83,15 @@ DISABLE_LLM_DOC_RELEVANCE = (
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
# Set this to "true" to hard delete chats
|
||||
# This will make chats unviewable by admins after a user deletes them
|
||||
# As opposed to soft deleting them, which just hides them from non-admin users
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
@@ -57,9 +57,12 @@ KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
|
||||
KV_SETTINGS_KEY = "danswer_settings"
|
||||
KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
@@ -96,6 +99,7 @@ class DocumentSource(str, Enum):
|
||||
CLICKUP = "clickup"
|
||||
MEDIAWIKI = "mediawiki"
|
||||
WIKIPEDIA = "wikipedia"
|
||||
ASANA = "asana"
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
@@ -130,6 +134,12 @@ class AuthType(str, Enum):
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
SEARCH = "Search"
|
||||
SLACK = "Slack"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
@@ -166,3 +176,25 @@ class FileOrigin(str, Enum):
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
|
||||
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
@@ -39,9 +39,13 @@ SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
|
||||
# User's set embedding batch size overrides the default encoding batch sizes
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
|
||||
|
||||
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
|
||||
# don't send over too many chunks at once, as sending too many could cause timeouts
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
|
||||
# For score display purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 1
|
||||
CROSS_ENCODER_RANGE_MIN = 0
|
||||
@@ -51,33 +55,11 @@ CROSS_ENCODER_RANGE_MIN = 0
|
||||
# Generative AI Model Configs
|
||||
#####
|
||||
|
||||
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
|
||||
# be sure to use one that is LiteLLM compatible:
|
||||
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
|
||||
# The provider is the prefix before / in the model argument
|
||||
|
||||
# Additionally Danswer supports GPT4All and custom request library based models
|
||||
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
|
||||
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
|
||||
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
|
||||
# If using Azure, it's the engine name, for example: Danswer
|
||||
# NOTE: the 3 below should only be used for dev.
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
|
||||
|
||||
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
|
||||
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
|
||||
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
|
||||
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = (
|
||||
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
|
||||
)
|
||||
|
||||
# API Base, such as (for Azure): https://danswer.openai.azure.com/
|
||||
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
|
||||
# API Version, such as (for Azure): 2023-09-15-preview
|
||||
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
# LiteLLM custom_llm_provider
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ if __name__ == "__main__":
|
||||
latest_docs = test_connector.poll_source(one_day_ago, current)
|
||||
```
|
||||
|
||||
> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main.
|
||||
|
||||
|
||||
### Additional Required Changes:
|
||||
#### Backend Changes
|
||||
|
||||
233
backend/danswer/connectors/asana/asana_api.py
Executable file
233
backend/danswer/connectors/asana/asana_api.py
Executable file
@@ -0,0 +1,233 @@
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import asana # type: ignore
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
|
||||
class AsanaTask:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
title: str,
|
||||
text: str,
|
||||
link: str,
|
||||
last_modified: datetime,
|
||||
project_gid: str,
|
||||
project_name: str,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.text = text
|
||||
self.link = link
|
||||
self.last_modified = last_modified
|
||||
self.project_gid = project_gid
|
||||
self.project_name = project_name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
|
||||
|
||||
|
||||
class AsanaAPI:
|
||||
def __init__(
|
||||
self, api_token: str, workspace_gid: str, team_gid: str | None
|
||||
) -> None:
|
||||
self._user = None # type: ignore
|
||||
self.workspace_gid = workspace_gid
|
||||
self.team_gid = team_gid
|
||||
|
||||
self.configuration = asana.Configuration()
|
||||
self.api_client = asana.ApiClient(self.configuration)
|
||||
self.tasks_api = asana.TasksApi(self.api_client)
|
||||
self.stories_api = asana.StoriesApi(self.api_client)
|
||||
self.users_api = asana.UsersApi(self.api_client)
|
||||
self.project_api = asana.ProjectsApi(self.api_client)
|
||||
self.workspaces_api = asana.WorkspacesApi(self.api_client)
|
||||
|
||||
self.api_error_count = 0
|
||||
self.configuration.access_token = api_token
|
||||
self.task_count = 0
|
||||
|
||||
def get_tasks(
|
||||
self, project_gids: list[str] | None, start_date: str
|
||||
) -> Iterator[AsanaTask]:
|
||||
"""Get all tasks from the projects with the given gids that were modified since the given date.
|
||||
If project_gids is None, get all tasks from all projects in the workspace."""
|
||||
logger.info("Starting to fetch Asana projects")
|
||||
projects = self.project_api.get_projects(
|
||||
opts={
|
||||
"workspace": self.workspace_gid,
|
||||
"opt_fields": "gid,name,archived,modified_at",
|
||||
}
|
||||
)
|
||||
start_seconds = int(time.mktime(datetime.now().timetuple()))
|
||||
projects_list = []
|
||||
project_count = 0
|
||||
for project_info in projects:
|
||||
project_gid = project_info["gid"]
|
||||
if project_gids is None or project_gid in project_gids:
|
||||
projects_list.append(project_gid)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping project: {project_gid} - not in accepted project_gids"
|
||||
)
|
||||
project_count += 1
|
||||
if project_count % 100 == 0:
|
||||
logger.info(f"Processed {project_count} projects")
|
||||
|
||||
logger.info(f"Found {len(projects_list)} projects to process")
|
||||
for project_gid in projects_list:
|
||||
for task in self._get_tasks_for_project(
|
||||
project_gid, start_date, start_seconds
|
||||
):
|
||||
yield task
|
||||
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
|
||||
if self.api_error_count > 0:
|
||||
logger.warning(
|
||||
f"Encountered {self.api_error_count} API errors during task fetching"
|
||||
)
|
||||
|
||||
def _get_tasks_for_project(
|
||||
self, project_gid: str, start_date: str, start_seconds: int
|
||||
) -> Iterator[AsanaTask]:
|
||||
project = self.project_api.get_project(project_gid, opts={})
|
||||
if project["archived"]:
|
||||
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
|
||||
return []
|
||||
if not project["team"] or not project["team"]["gid"]:
|
||||
logger.info(
|
||||
f"Skipping project without a team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
if project["privacy_setting"] == "private":
|
||||
if self.team_gid and project["team"]["gid"] != self.team_gid:
|
||||
logger.info(
|
||||
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
simple_start_date = start_date.split(".")[0].split("+")[0]
|
||||
logger.info(
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
opts = {
|
||||
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
|
||||
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
|
||||
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
|
||||
"workspace,permalink_url",
|
||||
"modified_since": start_date,
|
||||
}
|
||||
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
|
||||
for data in tasks_from_api:
|
||||
self.task_count += 1
|
||||
if self.task_count % 10 == 0:
|
||||
end_seconds = time.mktime(datetime.now().timetuple())
|
||||
runtime_seconds = end_seconds - start_seconds
|
||||
if runtime_seconds > 0:
|
||||
logger.info(
|
||||
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
|
||||
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing Asana task: {data['name']}")
|
||||
|
||||
text = self._construct_task_text(data)
|
||||
|
||||
try:
|
||||
text += self._fetch_and_add_comments(data["gid"])
|
||||
|
||||
last_modified_date = self.format_date(data["modified_at"])
|
||||
text += f"Last modified: {last_modified_date}\n"
|
||||
|
||||
task = AsanaTask(
|
||||
id=data["gid"],
|
||||
title=data["name"],
|
||||
text=text,
|
||||
link=data["permalink_url"],
|
||||
last_modified=datetime.fromisoformat(data["modified_at"]),
|
||||
project_gid=project_gid,
|
||||
project_name=project["name"],
|
||||
)
|
||||
yield task
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error processing task {data['gid']} in project {project_gid}",
|
||||
exc_info=True,
|
||||
)
|
||||
self.api_error_count += 1
|
||||
|
||||
def _construct_task_text(self, data: Dict) -> str:
|
||||
text = f"{data['name']}\n\n"
|
||||
|
||||
if data["notes"]:
|
||||
text += f"{data['notes']}\n\n"
|
||||
|
||||
if data["created_by"] and data["created_by"]["gid"]:
|
||||
creator = self.get_user(data["created_by"]["gid"])["name"]
|
||||
created_date = self.format_date(data["created_at"])
|
||||
text += f"Created by: {creator} on {created_date}\n"
|
||||
|
||||
if data["due_on"]:
|
||||
due_date = self.format_date(data["due_on"])
|
||||
text += f"Due date: {due_date}\n"
|
||||
|
||||
if data["completed_at"]:
|
||||
completed_date = self.format_date(data["completed_at"])
|
||||
text += f"Completed on: {completed_date}\n"
|
||||
|
||||
text += "\n"
|
||||
return text
|
||||
|
||||
def _fetch_and_add_comments(self, task_gid: str) -> str:
|
||||
text = ""
|
||||
stories_opts: Dict[str, str] = {}
|
||||
story_start = time.time()
|
||||
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
|
||||
|
||||
story_count = 0
|
||||
comment_count = 0
|
||||
|
||||
for story in stories:
|
||||
story_count += 1
|
||||
if story["resource_subtype"] == "comment_added":
|
||||
comment = self.stories_api.get_story(
|
||||
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
|
||||
)
|
||||
commenter = self.get_user(comment["created_by"]["gid"])["name"]
|
||||
text += f"Comment by {commenter}: {comment['text']}\n\n"
|
||||
comment_count += 1
|
||||
|
||||
story_duration = time.time() - story_start
|
||||
logger.debug(
|
||||
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return text
|
||||
|
||||
def get_user(self, user_gid: str) -> Dict:
|
||||
if self._user is not None:
|
||||
return self._user
|
||||
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
|
||||
|
||||
if not self._user:
|
||||
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
|
||||
return {"name": "Unknown"}
|
||||
return self._user
|
||||
|
||||
def format_date(self, date_str: str) -> str:
|
||||
date = datetime.fromisoformat(date_str)
|
||||
return time.strftime("%Y-%m-%d", date.timetuple())
|
||||
|
||||
def get_time(self) -> str:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
120
backend/danswer/connectors/asana/connector.py
Executable file
120
backend/danswer/connectors/asana/connector.py
Executable file
@@ -0,0 +1,120 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.asana import asana_api
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class AsanaConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
asana_workspace_id: str,
|
||||
asana_project_ids: str | None = None,
|
||||
asana_team_id: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.api_token = credentials["asana_api_token_secret"]
|
||||
self.asana_client = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
logger.info("Asana credentials loaded and API client initialized")
|
||||
return None
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_time = datetime.datetime.fromtimestamp(start).isoformat()
|
||||
logger.info(f"Starting Asana poll from {start_time}")
|
||||
asana = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
docs_batch: list[Document] = []
|
||||
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
|
||||
|
||||
for task in tasks:
|
||||
doc = self._message_to_doc(task)
|
||||
docs_batch.append(doc)
|
||||
|
||||
if len(docs_batch) >= self.batch_size:
|
||||
logger.info(f"Yielding batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
docs_batch = []
|
||||
|
||||
if docs_batch:
|
||||
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
|
||||
logger.info("Asana poll completed")
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.notice("Starting full index of all Asana tasks")
|
||||
return self.poll_source(start=0, end=None)
|
||||
|
||||
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
|
||||
logger.debug(f"Converting Asana task {task.id} to Document")
|
||||
return Document(
|
||||
id=task.id,
|
||||
sections=[Section(link=task.link, text=task.text)],
|
||||
doc_updated_at=task.last_modified,
|
||||
source=DocumentSource.ASANA,
|
||||
semantic_identifier=task.title,
|
||||
metadata={
|
||||
"group": task.project_gid,
|
||||
"project": task.project_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
|
||||
logger.notice("Starting Asana connector test")
|
||||
connector = AsanaConnector(
|
||||
os.environ["WORKSPACE_ID"],
|
||||
os.environ["PROJECT_IDS"],
|
||||
os.environ["TEAM_ID"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"asana_api_token_secret": os.environ["API_TOKEN"],
|
||||
}
|
||||
)
|
||||
logger.info("Loading all documents from Asana")
|
||||
all_docs = connector.load_from_state()
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
logger.info("Polling for documents updated in the last 24 hours")
|
||||
latest_docs = connector.poll_source(one_day_ago, current)
|
||||
for docs in latest_docs:
|
||||
for doc in docs:
|
||||
print(doc.id)
|
||||
logger.notice("Asana connector test completed")
|
||||
@@ -7,7 +7,6 @@ from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
@@ -53,79 +52,6 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
|
||||
)
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
|
||||
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split("/spaces")[0]
|
||||
)
|
||||
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base print()
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
|
||||
wiki_url
|
||||
)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
@@ -372,7 +298,10 @@ class RecursiveIndexer:
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_page_url: str,
|
||||
wiki_base: str,
|
||||
space: str,
|
||||
is_cloud: bool,
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -386,15 +315,15 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = index_recursively
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
self.page_id,
|
||||
self.is_cloud,
|
||||
) = extract_confluence_keys_from_url(wiki_page_url)
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
self.space = space
|
||||
self.page_id = page_id
|
||||
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
self.space_level_scan = False
|
||||
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
@@ -414,7 +343,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
cloud=self.is_cloud,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -866,7 +794,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
index_recursively=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
|
||||
@@ -23,7 +23,7 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
max_retries = 10
|
||||
max_retries = 5
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
@@ -32,17 +32,24 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
try:
|
||||
retry_after = int(e.response.headers.get("Retry-After"))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after:
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
|
||||
@@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_content(content: dict) -> str:
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if "content" in content:
|
||||
for block in content["content"]:
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
@@ -72,18 +77,15 @@ def _get_comment_strs(
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
if hasattr(comment, "body"):
|
||||
body_text = extract_text_from_content(comment.raw["body"])
|
||||
elif hasattr(comment, "raw"):
|
||||
body = comment.raw.get("body", "No body content available")
|
||||
body_text = (
|
||||
extract_text_from_content(body) if isinstance(body, dict) else body
|
||||
)
|
||||
else:
|
||||
body_text = "No body attribute found"
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
@@ -126,11 +128,14 @@ def fetch_jira_issues_batch(
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = (
|
||||
f"{jira.fields.description}\n"
|
||||
if jira.fields.description
|
||||
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Type
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.asana.connector import AsanaConnector
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
@@ -91,6 +92,7 @@ def identify_connector_class(
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.ASANA: AsanaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
@@ -124,11 +126,11 @@ def identify_connector_class(
|
||||
|
||||
|
||||
def instantiate_connector(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
db_session: Session,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
connector = connector_class(**connector_specific_config)
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_authorized_user,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_service_account,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -62,6 +55,8 @@ class GDriveMimeType(str, Enum):
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
@@ -316,19 +311,22 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = "text/plain"
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
export_mime_type = "text/csv"
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
export_mime_type = "text/plain"
|
||||
|
||||
response = (
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return response.decode("utf-8")
|
||||
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
@@ -402,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
)
|
||||
creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = creds.to_json() if creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
creds = get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
@@ -504,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import BASE_SCOPES
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
@@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_gdrive_scopes() -> list[str]:
|
||||
base_scopes: list[str] = BASE_SCOPES
|
||||
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
|
||||
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
return base_scopes + permissions_scopes + groups_scopes
|
||||
return base_scopes + permissions_scopes
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str,
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
@@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
|
||||
return None
|
||||
|
||||
|
||||
def get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str,
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=SCOPES
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str, scopes=scopes
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
@@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
]
|
||||
|
||||
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
|
||||
|
||||
@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
|
||||
# The default title is semantic_identifier though unless otherwise specified
|
||||
title: str | None = None
|
||||
from_ingestion_api: bool = False
|
||||
# Anything else that may be useful that is specific to this particular connector type that other
|
||||
# parts of the code may need. If you're unsure, this can be left as None
|
||||
additional_info: Any = None
|
||||
|
||||
def get_title_for_document_index(
|
||||
self,
|
||||
|
||||
@@ -237,6 +237,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
continue
|
||||
|
||||
if result_type == "external_object_instance_page":
|
||||
logger.warning(
|
||||
f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
|
||||
f"Notion API does not currently support reading external blocks (as of 24/07/03) "
|
||||
f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)"
|
||||
)
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
|
||||
@@ -98,6 +98,15 @@ class ProductboardConnector(PollConnector):
|
||||
owner = self._get_owner_email(feature)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
entity_type = feature.get("type", "feature")
|
||||
if entity_type:
|
||||
metadata["entity_type"] = str(entity_type)
|
||||
|
||||
status = feature.get("status", {}).get("name")
|
||||
if status:
|
||||
metadata["status"] = str(status)
|
||||
|
||||
yield Document(
|
||||
id=feature["id"],
|
||||
sections=[
|
||||
@@ -110,10 +119,7 @@ class ProductboardConnector(PollConnector):
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(feature["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"entity_type": feature["type"],
|
||||
"status": feature["status"]["name"],
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _get_components(self) -> Generator[Document, None, None]:
|
||||
@@ -174,6 +180,12 @@ class ProductboardConnector(PollConnector):
|
||||
owner = self._get_owner_email(objective)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"entity_type": "objective",
|
||||
}
|
||||
if objective.get("state"):
|
||||
metadata["state"] = str(objective["state"])
|
||||
|
||||
yield Document(
|
||||
id=objective["id"],
|
||||
sections=[
|
||||
@@ -186,10 +198,7 @@ class ProductboardConnector(PollConnector):
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(objective["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"entity_type": "release",
|
||||
"state": objective["state"],
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _is_updated_at_out_of_time_range(
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -137,7 +136,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
.execute_query()
|
||||
]
|
||||
else:
|
||||
sites = self.graph_client.sites.get().execute_query()
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
self.site_data = [
|
||||
SiteData(url=None, folder=None, sites=sites, driveitems=[])
|
||||
]
|
||||
|
||||
@@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import io
|
||||
import ipaddress
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -85,7 +87,8 @@ def check_internet_connection(url: str) -> None:
|
||||
response = requests.get(url, timeout=3)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
status_code = e.response.status_code
|
||||
# Extract status code from the response, defaulting to -1 if response is None
|
||||
status_code = e.response.status_code if e.response is not None else -1
|
||||
error_msg = {
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
@@ -202,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None:
|
||||
try:
|
||||
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
class WebConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -287,6 +299,7 @@ class WebConnector(LoadConnector):
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@@ -295,12 +308,22 @@ class WebConnector(LoadConnector):
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
@@ -336,6 +359,11 @@ class WebConnector(LoadConnector):
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
import requests
|
||||
from retry import retry
|
||||
from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects import Ticket # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -59,10 +60,15 @@ class ZendeskClientNotSetUpError(PermissionError):
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.zendesk_client: Zenpy | None = None
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.content_type = content_type
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def _set_content_tags(
|
||||
@@ -122,16 +128,86 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def _ticket_to_document(self, ticket: Ticket) -> Document:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
owner = None
|
||||
if ticket.requester and ticket.requester.name and ticket.requester.email:
|
||||
owner = [
|
||||
BasicExpertInfo(
|
||||
display_name=ticket.requester.name, email=ticket.requester.email
|
||||
)
|
||||
]
|
||||
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if ticket.status is not None:
|
||||
metadata["status"] = ticket.status
|
||||
if ticket.priority is not None:
|
||||
metadata["priority"] = ticket.priority
|
||||
if ticket.tags:
|
||||
metadata["tags"] = ticket.tags
|
||||
if ticket.type is not None:
|
||||
metadata["ticket_type"] = ticket.type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments = self.zendesk_client.tickets.comments(ticket=ticket)
|
||||
|
||||
# Combine all comments into a single text
|
||||
comments_text = "\n\n".join(
|
||||
[
|
||||
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
|
||||
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
|
||||
for comment in comments
|
||||
if comment.body
|
||||
]
|
||||
)
|
||||
|
||||
# Combine ticket description and comments
|
||||
description = (
|
||||
ticket.description
|
||||
if hasattr(ticket, "description") and ticket.description
|
||||
else ""
|
||||
)
|
||||
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
|
||||
|
||||
# Extract subdomain from ticket.url
|
||||
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
|
||||
|
||||
# Build the html url for the ticket
|
||||
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
|
||||
|
||||
return Document(
|
||||
id=f"zendesk_ticket_{ticket.id}",
|
||||
sections=[Section(link=ticket_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=owner,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
elif self.content_type == "tickets":
|
||||
yield from self._poll_tickets(start)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = (
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True)
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
|
||||
if start is None
|
||||
else self.zendesk_client.help_center.articles.incremental(
|
||||
else self.zendesk_client.help_center.articles.incremental( # type: ignore
|
||||
start_time=int(start)
|
||||
)
|
||||
)
|
||||
@@ -155,9 +231,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.status == "deleted":
|
||||
continue
|
||||
|
||||
doc_batch.append(self._ticket_to_document(ticket))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
except StopIteration:
|
||||
# No more tickets to process
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
@@ -360,22 +359,6 @@ def build_quotes_block(
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def build_standard_answer_blocks(
|
||||
answer_message: str,
|
||||
) -> list[Block]:
|
||||
generate_button_block = ButtonElement(
|
||||
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
|
||||
text="Generate Full Answer",
|
||||
)
|
||||
answer_block = SectionBlock(text=answer_message)
|
||||
return [
|
||||
answer_block,
|
||||
ActionsBlock(
|
||||
elements=[generate_button_block],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
|
||||
@@ -87,6 +88,8 @@ def handle_generate_answer_button(
|
||||
message_ts = req.payload["message"]["ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
user_id = req.payload["user"]["id"]
|
||||
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
if not thread_ts:
|
||||
raise ValueError("Missing thread_ts in the payload")
|
||||
@@ -125,6 +128,7 @@ def handle_generate_answer_button(
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=user_id or None,
|
||||
email=email or None,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=False,
|
||||
|
||||
@@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.users import add_non_web_user_if_not_exists
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
@@ -209,6 +210,9 @@ def handle_message(
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if message_info.email:
|
||||
add_non_web_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
used_standard_answer = handle_standard_answers(
|
||||
message_info=message_info,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
@@ -38,6 +39,7 @@ from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
@@ -99,6 +101,12 @@ def handle_regular_answer(
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_bot_config.persona if slack_bot_config else None
|
||||
@@ -128,7 +136,8 @@ def handle_regular_answer(
|
||||
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to:
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
)
|
||||
@@ -145,15 +154,23 @@ def handle_regular_answer(
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
if new_message_request.persona_config:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Slack bot does not support persona config",
|
||||
)
|
||||
|
||||
elif new_message_request.persona_id:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
@@ -185,7 +202,7 @@ def handle_regular_answer(
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=None,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
@@ -412,7 +429,7 @@ def handle_regular_answer(
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
@@ -463,7 +480,9 @@ def handle_regular_answer(
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /danswer message
|
||||
# so we shouldn't send_team_member_message
|
||||
if receiver_ids and message_ts_to_respond_to is not None:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
|
||||
@@ -1,60 +1,16 @@
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_sessions
|
||||
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from danswer.db.standard_answer import find_matching_standard_answers
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def oneoff_standard_answers(
|
||||
message: str,
|
||||
slack_bot_categories: list[str],
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
"""
|
||||
Respond to the user message if it matches any configured standard answers.
|
||||
|
||||
Returns a list of matching StandardAnswers if found, otherwise None.
|
||||
"""
|
||||
configured_standard_answers = {
|
||||
standard_answer
|
||||
for category in fetch_standard_answer_categories_by_names(
|
||||
slack_bot_categories, db_session=db_session
|
||||
)
|
||||
for standard_answer in category.standard_answers
|
||||
}
|
||||
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=message,
|
||||
id_in=[answer.id for answer in configured_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
server_standard_answers = [
|
||||
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
|
||||
]
|
||||
return server_standard_answers
|
||||
|
||||
|
||||
def handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
@@ -63,153 +19,38 @@ def handle_standard_answers(
|
||||
logger: DanswerLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Returns whether one or more Standard Answer message blocks were
|
||||
emitted by the Slack bot"""
|
||||
versioned_handle_standard_answers = fetch_versioned_implementation(
|
||||
"danswer.danswerbot.slack.handlers.handle_standard_answers",
|
||||
"_handle_standard_answers",
|
||||
)
|
||||
return versioned_handle_standard_answers(
|
||||
message_info=message_info,
|
||||
receiver_ids=receiver_ids,
|
||||
slack_bot_config=slack_bot_config,
|
||||
prompt=prompt,
|
||||
logger=logger,
|
||||
client=client,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
prompt: Prompt | None,
|
||||
logger: DanswerLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Potentially respond to the user message depending on whether the user's message matches
|
||||
any of the configured standard answers and also whether those answers have already been
|
||||
provided in the current thread.
|
||||
Standard Answers are a paid Enterprise Edition feature. This is the fallback
|
||||
function handling the case where EE features are not enabled.
|
||||
|
||||
Returns True if standard answers are found to match the user's message and therefore,
|
||||
we still need to respond to the users.
|
||||
Always returns false i.e. since EE features are not enabled, we NEVER create any
|
||||
Slack message blocks.
|
||||
"""
|
||||
# if no channel config, then no standard answers are configured
|
||||
if not slack_bot_config:
|
||||
return False
|
||||
|
||||
slack_thread_id = message_info.thread_to_respond
|
||||
configured_standard_answer_categories = (
|
||||
slack_bot_config.standard_answer_categories if slack_bot_config else []
|
||||
)
|
||||
configured_standard_answers = set(
|
||||
[
|
||||
standard_answer
|
||||
for standard_answer_category in configured_standard_answer_categories
|
||||
for standard_answer in standard_answer_category.standard_answers
|
||||
]
|
||||
)
|
||||
query_msg = message_info.thread_messages[-1]
|
||||
|
||||
if slack_thread_id is None:
|
||||
used_standard_answer_ids = set([])
|
||||
else:
|
||||
chat_sessions = get_chat_sessions_by_slack_thread_id(
|
||||
slack_thread_id=slack_thread_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
chat_messages = get_chat_messages_by_sessions(
|
||||
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
)
|
||||
used_standard_answer_ids = set(
|
||||
[
|
||||
standard_answer.id
|
||||
for chat_message in chat_messages
|
||||
for standard_answer in chat_message.standard_answers
|
||||
]
|
||||
)
|
||||
|
||||
usable_standard_answers = configured_standard_answers.difference(
|
||||
used_standard_answer_ids
|
||||
)
|
||||
if usable_standard_answers:
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=query_msg.message,
|
||||
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
matching_standard_answers = []
|
||||
if matching_standard_answers:
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="",
|
||||
user_id=None,
|
||||
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
one_shot=True,
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=query_msg.message,
|
||||
token_count=0,
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
formatted_answers = []
|
||||
for standard_answer in matching_standard_answers:
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = (
|
||||
f'Since you mentioned _"{standard_answer.keyword}"_, '
|
||||
f"I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
)
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
_ = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=answer_message,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
restate_question_blocks = get_restate_blocks(
|
||||
msg=query_msg.message,
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
)
|
||||
|
||||
answer_blocks = build_standard_answer_blocks(
|
||||
answer_message=answer_message,
|
||||
)
|
||||
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
if receiver_ids and slack_thread_id:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to send standard answer message: {e}")
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
@@ -55,6 +56,7 @@ from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
@@ -256,6 +258,11 @@ def build_request_details(
|
||||
tagged = event.get("type") == "app_mention"
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
sender = event.get("user") or None
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
|
||||
@@ -286,7 +293,8 @@ def build_request_details(
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=event.get("user") or None,
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
@@ -296,6 +304,10 @@ def build_request_details(
|
||||
channel = req.payload["channel_id"]
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
|
||||
@@ -305,6 +317,7 @@ def build_request_details(
|
||||
msg_to_respond=None,
|
||||
thread_to_respond=None,
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
@@ -469,6 +482,8 @@ if __name__ == "__main__":
|
||||
slack_bot_tokens: SlackBotTokens | None = None
|
||||
socket_client: SocketModeClient | None = None
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel):
|
||||
msg_to_respond: str | None
|
||||
thread_to_respond: str | None
|
||||
sender: str | None
|
||||
email: str | None
|
||||
bypass_filters: bool # User has tagged @DanswerBot
|
||||
is_bot_msg: bool # User is using /DanswerBot
|
||||
is_bot_dm: bool # User is direct messaging to DanswerBot
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
|
||||
get_default_admin_user_emails_fn: Callable[
|
||||
[], list[str]
|
||||
] = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
|
||||
)
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
@@ -87,29 +86,57 @@ def get_chat_sessions_by_slack_thread_id(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_first_messages_for_chat_sessions(
|
||||
chat_session_ids: list[int], db_session: Session
|
||||
def get_valid_messages_from_query_sessions(
|
||||
chat_session_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> dict[int, str]:
|
||||
subquery = (
|
||||
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
|
||||
user_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER, # Select USER messages
|
||||
)
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
|
||||
subquery,
|
||||
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
|
||||
& (ChatMessage.id == subquery.c.min_id),
|
||||
assistant_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id,
|
||||
func.min(ChatMessage.id).label("assistant_msg_id"),
|
||||
)
|
||||
.where(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(ChatMessage.chat_session_id, ChatMessage.message)
|
||||
.join(
|
||||
user_message_subquery,
|
||||
ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
assistant_message_subquery,
|
||||
ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
ChatMessage__SearchDoc,
|
||||
ChatMessage__SearchDoc.chat_message_id
|
||||
== assistant_message_subquery.c.assistant_msg_id,
|
||||
)
|
||||
.where(ChatMessage.id == user_message_subquery.c.user_msg_id)
|
||||
)
|
||||
|
||||
first_messages = db_session.execute(query).all()
|
||||
return dict([(row.chat_session_id, row.message) for row in first_messages])
|
||||
logger.info(f"Retrieved {len(first_messages)} first messages with documents")
|
||||
|
||||
return {row.chat_session_id: row.message for row in first_messages}
|
||||
|
||||
|
||||
def get_chat_sessions_by_user(
|
||||
@@ -199,7 +226,7 @@ def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
@@ -253,6 +280,13 @@ def delete_chat_session(
|
||||
db_session: Session,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Cannot delete an already deleted chat session")
|
||||
|
||||
if hard_delete:
|
||||
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
|
||||
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
|
||||
@@ -564,6 +598,7 @@ def get_doc_query_identifiers_from_model(
|
||||
chat_session: ChatSession,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
enforce_chat_session_id_for_search_docs: bool,
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Given a list of search_doc_ids"""
|
||||
search_docs = (
|
||||
@@ -583,7 +618,8 @@ def get_doc_query_identifiers_from_model(
|
||||
for doc in search_docs
|
||||
]
|
||||
):
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
if enforce_chat_session_id_for_search_docs:
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
except IndexError:
|
||||
# This happens when the doc has no chat_messages associated with it.
|
||||
# which happens as an edge case where the chat message failed to save
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
check_if_valid_sync_source,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -74,7 +79,7 @@ def _add_user_filters(
|
||||
.correlate(ConnectorCredentialPair)
|
||||
)
|
||||
else:
|
||||
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -94,8 +99,19 @@ def get_connector_credential_pairs(
|
||||
) # noqa
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def add_deletion_failure_message(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
failure_message: str,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
return
|
||||
cc_pair.deletion_failure_message = failure_message
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_cc_pair_groups_for_ids(
|
||||
@@ -297,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
association = ConnectorCredentialPair(
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
access_type=AccessType.PUBLIC,
|
||||
name="DefaultCCPair",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=True,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.commit()
|
||||
@@ -324,8 +340,9 @@ def add_credential_to_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
is_public: bool,
|
||||
access_type: AccessType,
|
||||
groups: list[int] | None,
|
||||
auto_sync_options: dict | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
@@ -333,10 +350,21 @@ def add_credential_to_connector(
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if access_type == AccessType.SYNC:
|
||||
if not check_if_valid_sync_source(connector.source):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connector of type {connector.source} does not support SYNC access type",
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
error_msg = (
|
||||
f"Credential {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credential does not exist or does not belong to user",
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
existing_association = (
|
||||
@@ -350,7 +378,7 @@ def add_credential_to_connector(
|
||||
if existing_association is not None:
|
||||
return StatusResponse(
|
||||
success=False,
|
||||
message=f"Connector already has Credential {credential_id}",
|
||||
message=f"Connector {connector_id} already has Credential {credential_id}",
|
||||
data=connector_id,
|
||||
)
|
||||
|
||||
@@ -359,12 +387,13 @@ def add_credential_to_connector(
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=is_public,
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
|
||||
if groups:
|
||||
if groups and access_type != AccessType.SYNC:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
@@ -374,8 +403,8 @@ def add_credential_to_connector(
|
||||
db_session.commit()
|
||||
|
||||
return StatusResponse(
|
||||
success=False,
|
||||
message=f"Connector already has Credential {credential_id}",
|
||||
success=True,
|
||||
message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
|
||||
data=association.id,
|
||||
)
|
||||
|
||||
@@ -407,6 +436,10 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
if association is not None:
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
)
|
||||
db_session.delete(association)
|
||||
db_session.commit()
|
||||
return StatusResponse(
|
||||
|
||||
@@ -3,26 +3,30 @@ import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.engine.util import TransactionalContext
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.db.tag import delete_document_tags_for_documents__no_commit
|
||||
from danswer.db.utils import model_to_dict
|
||||
from danswer.document_index.interfaces import DocumentMetadata
|
||||
@@ -38,6 +42,68 @@ def check_docs_exist(db_session: Session) -> bool:
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def count_documents_by_needs_sync(session: Session) -> int:
|
||||
"""Get the count of all documents where:
|
||||
1. last_modified is newer than last_synced
|
||||
2. last_synced is null (meaning we've never synced)
|
||||
|
||||
This function executes the query and returns the count of
|
||||
documents matching the criteria."""
|
||||
|
||||
count = (
|
||||
session.query(func.count())
|
||||
.select_from(DbDocument)
|
||||
.filter(
|
||||
or_(
|
||||
DbDocument.last_modified > DbDocument.last_synced,
|
||||
DbDocument.last_synced.is_(None),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
connector_id: int, credential_id: int
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.where(
|
||||
DbDocument.id.in_(initial_doc_ids_stmt),
|
||||
or_(
|
||||
DbDocument.last_modified
|
||||
> DbDocument.last_synced, # last_modified is newer than last_synced
|
||||
DbDocument.last_synced.is_(None), # never synced
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair(
|
||||
connector_id: int, credential_id: int | None = None
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
|
||||
return stmt
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -62,7 +128,18 @@ def get_documents_by_ids(
|
||||
return list(documents)
|
||||
|
||||
|
||||
def get_document_connector_cnts(
|
||||
def get_document_connector_count(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> int:
|
||||
results = get_document_connector_counts(db_session, [document_id])
|
||||
if not results or len(results) == 0:
|
||||
return 0
|
||||
|
||||
return results[0][1]
|
||||
|
||||
|
||||
def get_document_connector_counts(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, int]]:
|
||||
@@ -77,7 +154,7 @@ def get_document_connector_cnts(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_document_cnts_for_cc_pairs(
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
stmt = (
|
||||
@@ -108,22 +185,50 @@ def get_document_cnts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_acccess_info_for_documents(
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> tuple[str, list[str | None], bool] | None:
|
||||
"""Gets access info for a single document by calling the get_access_info_for_documents function
|
||||
and passing a list with a single document ID.
|
||||
Args:
|
||||
db_session (Session): The database session to use.
|
||||
document_id (str): The document ID to fetch access info for.
|
||||
Returns:
|
||||
Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
|
||||
and a boolean indicating if the document is globally public, or None if no results are found.
|
||||
"""
|
||||
results = get_access_info_for_documents(db_session, [document_id])
|
||||
if not results:
|
||||
return None
|
||||
|
||||
return results[0]
|
||||
|
||||
|
||||
def get_access_info_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, list[UUID | None], bool]]:
|
||||
) -> Sequence[tuple[str, list[str | None], bool]]:
|
||||
"""Gets back all relevant access info for the given documents. This includes
|
||||
the user_ids for cc pairs that the document is associated with + whether any
|
||||
of the associated cc pairs are intending to make the document globally public.
|
||||
Returns the list where each element contains:
|
||||
- Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
|
||||
- List of emails of Danswer users with direct access to the doc (includes a "None" element if
|
||||
the connector was set up by an admin when auth was off
|
||||
- bool for whether the document is public (the document later can also be marked public by
|
||||
automatic permission sync step)
|
||||
"""
|
||||
stmt = select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
|
||||
func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
|
||||
"public_doc"
|
||||
),
|
||||
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(Credential.user_id).label("user_ids"),
|
||||
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
|
||||
)
|
||||
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
.join(
|
||||
stmt.join(
|
||||
Credential,
|
||||
DocumentByConnectorCredentialPair.credential_id == Credential.id,
|
||||
)
|
||||
@@ -136,6 +241,13 @@ def get_acccess_info_for_documents(
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.outerjoin(
|
||||
User,
|
||||
and_(
|
||||
Credential.user_id == User.id,
|
||||
ConnectorCredentialPair.access_type != AccessType.SYNC,
|
||||
),
|
||||
)
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
|
||||
@@ -173,6 +285,7 @@ def upsert_documents(
|
||||
semantic_id=doc.semantic_identifier,
|
||||
link=doc.first_link,
|
||||
doc_updated_at=None, # this is intentional
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
)
|
||||
@@ -180,9 +293,19 @@ def upsert_documents(
|
||||
for doc in seen_documents.values()
|
||||
]
|
||||
)
|
||||
# for now, there are no columns to update. If more metadata is added, then this
|
||||
# needs to change to an `on_conflict_do_update`
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
|
||||
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
},
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
@@ -214,7 +337,7 @@ def upsert_document_by_connector_credential_pair(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_docs_updated_at(
|
||||
def update_docs_updated_at__no_commit(
|
||||
ids_to_new_updated_at: dict[str, datetime],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -226,6 +349,28 @@ def update_docs_updated_at(
|
||||
for document in documents_to_update:
|
||||
document.doc_updated_at = ids_to_new_updated_at[document.id]
|
||||
|
||||
|
||||
def update_docs_last_modified__no_commit(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
documents_to_update = (
|
||||
db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
for doc in documents_to_update:
|
||||
doc.last_modified = now
|
||||
|
||||
|
||||
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc = db_session.scalar(stmt)
|
||||
if doc is None:
|
||||
raise ValueError(f"No document with ID: {document_id}")
|
||||
|
||||
# update last_synced
|
||||
doc.last_synced = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -241,11 +386,34 @@ def upsert_documents_complete(
|
||||
|
||||
|
||||
def delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""Deletes a single document by cc pair relationship entry.
|
||||
Foreign key rows are left in place.
|
||||
The implicit assumption is that the document itself still has other cc_pair
|
||||
references and needs to continue existing.
|
||||
"""
|
||||
delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
connector_credential_pair_identifier=connector_credential_pair_identifier,
|
||||
)
|
||||
|
||||
|
||||
def delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""This deletes just the document by cc pair entries for a particular cc pair.
|
||||
Foreign key rows are left in place.
|
||||
The implicit assumption is that the document itself still has other cc_pair
|
||||
references and needs to continue existing.
|
||||
"""
|
||||
stmt = delete(DocumentByConnectorCredentialPair).where(
|
||||
DocumentByConnectorCredentialPair.id.in_(document_ids)
|
||||
)
|
||||
@@ -268,8 +436,9 @@ def delete_documents__no_commit(db_session: Session, document_ids: list[str]) ->
|
||||
def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
logger.info(f"Deleting {len(document_ids)} documents from the DB")
|
||||
delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
@@ -379,3 +548,12 @@ def get_documents_by_cc_pair(
|
||||
.filter(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DbDocument | None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
return doc
|
||||
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
@@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
|
||||
ids=missing_cc_pair_ids,
|
||||
)
|
||||
for cc_pair in cc_pairs:
|
||||
if not cc_pair.is_public:
|
||||
if cc_pair.access_type != AccessType.PUBLIC:
|
||||
raise ValueError(
|
||||
f"Connector Credential Pair with ID: '{cc_pair.id}'"
|
||||
" is not owned by the specified groups"
|
||||
@@ -248,6 +249,10 @@ def update_document_set(
|
||||
document_set_update_request: DocumentSetUpdateRequest,
|
||||
user: User | None = None,
|
||||
) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
|
||||
"""If successful, this sets document_set_row.is_up_to_date = False.
|
||||
That will be processed via Celery in check_for_vespa_sync_task
|
||||
and trigger a long running background sync to Vespa.
|
||||
"""
|
||||
if not document_set_update_request.cc_pair_ids:
|
||||
# It's cc-pairs in actuality but the UI displays this error
|
||||
raise ValueError("Cannot create a document set with no Connectors")
|
||||
@@ -519,11 +524,101 @@ def fetch_documents_for_document_set_paginated(
|
||||
return documents, documents[-1].id if documents else None
|
||||
|
||||
|
||||
def construct_document_select_by_docset(
|
||||
document_set_id: int,
|
||||
current_only: bool = True,
|
||||
) -> Select:
|
||||
"""This returns a statement that should be executed using
|
||||
.yield_per() to minimize overhead. The primary consumers of this function
|
||||
are background processing task generators."""
|
||||
|
||||
stmt = (
|
||||
select(Document)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DocumentByConnectorCredentialPair.id == Document.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
ConnectorCredentialPair.connector_id
|
||||
== DocumentByConnectorCredentialPair.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== DocumentByConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
|
||||
== ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
DocumentSetDBModel,
|
||||
DocumentSetDBModel.id
|
||||
== DocumentSet__ConnectorCredentialPair.document_set_id,
|
||||
)
|
||||
.where(DocumentSetDBModel.id == document_set_id)
|
||||
.order_by(Document.id)
|
||||
)
|
||||
|
||||
if current_only:
|
||||
stmt = stmt.where(
|
||||
DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
|
||||
)
|
||||
|
||||
stmt = stmt.distinct()
|
||||
return stmt
|
||||
|
||||
|
||||
def fetch_document_sets_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Fetches the document set names for a single document ID.
|
||||
|
||||
:param document_id: The ID of the document to fetch sets for.
|
||||
:param db_session: The SQLAlchemy session to use for the query.
|
||||
:return: A list of document set names, or None if no result is found.
|
||||
"""
|
||||
result = fetch_document_sets_for_documents([document_id], db_session)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
return result[0][1]
|
||||
|
||||
|
||||
def fetch_document_sets_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> Sequence[tuple[str, list[str]]]:
|
||||
"""Gives back a list of (document_id, list[document_set_names]) tuples"""
|
||||
|
||||
"""Building subqueries"""
|
||||
# NOTE: have to build these subqueries first in order to guarantee that we get one
|
||||
# returned row for each specified document_id. Basically, we want to do the filters first,
|
||||
# then the outer joins.
|
||||
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
# as we can assume their document sets are no longer relevant
|
||||
valid_cc_pairs_subquery = aliased(
|
||||
ConnectorCredentialPair,
|
||||
select(ConnectorCredentialPair)
|
||||
.where(
|
||||
ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING
|
||||
) # noqa: E712
|
||||
.subquery(),
|
||||
)
|
||||
|
||||
valid_document_set__cc_pairs_subquery = aliased(
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
select(DocumentSet__ConnectorCredentialPair)
|
||||
.where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712
|
||||
.subquery(),
|
||||
)
|
||||
"""End building subqueries"""
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
Document.id,
|
||||
@@ -531,39 +626,33 @@ def fetch_document_sets_for_documents(
|
||||
func.array_remove(func.array_agg(DocumentSetDBModel.name), None), []
|
||||
).label("document_set_names"),
|
||||
)
|
||||
# Here we select document sets by relation:
|
||||
# Document -> DocumentByConnectorCredentialPair -> ConnectorCredentialPair ->
|
||||
# DocumentSet__ConnectorCredentialPair -> DocumentSet
|
||||
.outerjoin(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.outerjoin(
|
||||
ConnectorCredentialPair,
|
||||
valid_cc_pairs_subquery,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
== valid_cc_pairs_subquery.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
== valid_cc_pairs_subquery.credential_id,
|
||||
),
|
||||
)
|
||||
.outerjoin(
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
ConnectorCredentialPair.id
|
||||
== DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
|
||||
valid_document_set__cc_pairs_subquery,
|
||||
valid_cc_pairs_subquery.id
|
||||
== valid_document_set__cc_pairs_subquery.connector_credential_pair_id,
|
||||
)
|
||||
.outerjoin(
|
||||
DocumentSetDBModel,
|
||||
DocumentSetDBModel.id
|
||||
== DocumentSet__ConnectorCredentialPair.document_set_id,
|
||||
== valid_document_set__cc_pairs_subquery.document_set_id,
|
||||
)
|
||||
.where(Document.id.in_(document_ids))
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
# as we can assume their document sets are no longer relevant
|
||||
.where(
|
||||
ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
.where(
|
||||
DocumentSet__ConnectorCredentialPair.is_current == True, # noqa: E712
|
||||
)
|
||||
.group_by(Document.id)
|
||||
)
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
@@ -616,7 +705,7 @@ def check_document_sets_are_public(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
connector_credential_pair_ids # type:ignore
|
||||
),
|
||||
ConnectorCredentialPair.is_public.is_(False),
|
||||
ConnectorCredentialPair.access_type != AccessType.PUBLIC,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
|
||||
@@ -137,8 +137,8 @@ def get_sqlalchemy_engine() -> Engine:
|
||||
)
|
||||
_SYNC_ENGINE = create_engine(
|
||||
connection_string,
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
@@ -156,8 +156,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
connect_args={
|
||||
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
|
||||
},
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
@@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return self == ConnectorCredentialPairStatus.ACTIVE
|
||||
|
||||
|
||||
class AccessType(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
SYNC = "sync"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -14,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ChatMessageFeedback
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -24,7 +27,6 @@ from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -93,7 +95,7 @@ def _add_user_filters(
|
||||
.correlate(CCPair)
|
||||
)
|
||||
else:
|
||||
where_clause |= CCPair.is_public == True # noqa: E712
|
||||
where_clause |= CCPair.access_type == AccessType.PUBLIC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -123,12 +125,11 @@ def update_document_boost(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
boost: int,
|
||||
document_index: DocumentIndex,
|
||||
user: User | None = None,
|
||||
) -> None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=True)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Document is not editable by this user"
|
||||
@@ -136,13 +137,9 @@ def update_document_boost(
|
||||
|
||||
result.boost = boost
|
||||
|
||||
update = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
boost=boost,
|
||||
)
|
||||
|
||||
document_index.update(update_requests=[update])
|
||||
|
||||
# updating last_modified triggers sync
|
||||
# TODO: Should this submit to the queue directly so that the UI can update?
|
||||
result.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -163,13 +160,9 @@ def update_document_hidden(
|
||||
|
||||
result.hidden = hidden
|
||||
|
||||
update = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
hidden=hidden,
|
||||
)
|
||||
|
||||
document_index.update(update_requests=[update])
|
||||
|
||||
# updating last_modified triggers sync
|
||||
# TODO: Should this submit to the queue directly so that the UI can update?
|
||||
result.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -210,11 +203,9 @@ def create_doc_retrieval_feedback(
|
||||
SearchFeedbackType.REJECT,
|
||||
SearchFeedbackType.HIDE,
|
||||
]:
|
||||
update = UpdateRequest(
|
||||
document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
|
||||
)
|
||||
# Updates are generally batched for efficiency, this case only 1 doc/value is updated
|
||||
document_index.update(update_requests=[update])
|
||||
# updating last_modified triggers sync
|
||||
# TODO: Should this submit to the queue directly so that the UI can update?
|
||||
db_doc.last_modified = datetime.now(timezone.utc)
|
||||
|
||||
db_session.add(retrieval_feedback)
|
||||
db_session.commit()
|
||||
|
||||
@@ -181,6 +181,45 @@ def get_last_attempt(
|
||||
return db_session.execute(stmt).scalars().first()
|
||||
|
||||
|
||||
def get_latest_index_attempts_by_status(
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
status: IndexingStatus,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
"""
|
||||
Retrieves the most recent index attempt with the specified status for each connector_credential_pair.
|
||||
Filters attempts based on the secondary_index flag to get either future or present index attempts.
|
||||
Returns a sequence of IndexAttempt objects, one for each unique connector_credential_pair.
|
||||
"""
|
||||
latest_failed_attempts = (
|
||||
select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_failed_id"),
|
||||
)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.where(
|
||||
SearchSettings.status
|
||||
== (
|
||||
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
),
|
||||
IndexAttempt.status == status,
|
||||
)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(IndexAttempt).join(
|
||||
latest_failed_attempts,
|
||||
(
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== latest_failed_attempts.c.connector_credential_pair_id
|
||||
)
|
||||
& (IndexAttempt.id == latest_failed_attempts.c.max_failed_id),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_latest_index_attempts(
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
@@ -211,12 +250,41 @@ def get_latest_index_attempts(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_index_attempts_for_connector(
|
||||
def count_index_attempts_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
) -> int:
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
)
|
||||
if disinclude_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
)
|
||||
)
|
||||
if only_current:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
# Count total items for pagination
|
||||
count_stmt = stmt.with_only_columns(func.count()).order_by(None)
|
||||
total_count = db_session.execute(count_stmt).scalar_one()
|
||||
return total_count
|
||||
|
||||
|
||||
def get_paginated_index_attempts_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
page: int,
|
||||
page_size: int,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> list[IndexAttempt]:
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
@@ -233,22 +301,30 @@ def get_index_attempts_for_connector(
|
||||
SearchSettings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(IndexAttempt.time_created.desc())
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
stmt = stmt.order_by(IndexAttempt.time_started.desc())
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def get_latest_finished_index_attempt_for_cc_pair(
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).distinct()
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if only_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if secondary_index:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.FUTURE
|
||||
@@ -295,14 +371,21 @@ def get_index_attempts_for_cc_pair(
|
||||
|
||||
|
||||
def delete_index_attempts(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# First, delete related entries in IndexAttemptErrors
|
||||
stmt_errors = delete(IndexAttemptError).where(
|
||||
IndexAttemptError.index_attempt_id.in_(
|
||||
select(IndexAttempt.id).where(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt_errors)
|
||||
|
||||
stmt = delete(IndexAttempt).where(
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
@@ -4,8 +4,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import Tool as ToolModel
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -50,6 +53,7 @@ def upsert_cloud_embedding_provider(
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
|
||||
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
@@ -58,13 +62,21 @@ def upsert_cloud_embedding_provider(
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
is_creation: bool = True,
|
||||
) -> FullLLMProvider:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider and is_creation:
|
||||
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
|
||||
|
||||
if not existing_llm_provider:
|
||||
if not is_creation:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider.name} does not exist"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -101,6 +113,20 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_doc_sets(
|
||||
db_session: Session, doc_ids: list[int]
|
||||
) -> list[DocumentSet]:
|
||||
return list(
|
||||
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
|
||||
return list(
|
||||
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
@@ -157,12 +183,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
|
||||
)
|
||||
|
||||
# Delete the embedding provider
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.provider_type == provider_type
|
||||
)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
# Remove LLMProvider's dependent relationships
|
||||
@@ -178,7 +211,7 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_default_provider(db_session: Session, provider_id: int) -> None:
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
new_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
|
||||
@@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
@@ -61,7 +62,7 @@ from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
@@ -108,7 +109,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
)
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
@@ -122,7 +123,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
chosen_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
|
||||
)
|
||||
visible_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -157,6 +164,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
notifications: Mapped[list["Notification"]] = relationship(
|
||||
"Notification", back_populates="user"
|
||||
)
|
||||
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
|
||||
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
@@ -168,7 +177,9 @@ class InputPrompt(Base):
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
@@ -212,7 +223,9 @@ class Notification(Base):
|
||||
notif_type: Mapped[NotificationType] = mapped_column(
|
||||
Enum(NotificationType, native_enum=False)
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
@@ -247,7 +260,7 @@ class Persona__User(Base):
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -258,7 +271,7 @@ class DocumentSet__User(Base):
|
||||
ForeignKey("document_set.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -373,16 +386,29 @@ class ConnectorCredentialPair(Base):
|
||||
connector_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector.id"), primary_key=True
|
||||
)
|
||||
|
||||
deletion_failure_message: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
credential_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("credential.id"), primary_key=True
|
||||
)
|
||||
# controls whether the documents indexed by this CC pair are visible to all
|
||||
# or if they are only visible to those with that are given explicit access
|
||||
# (e.g. via owning the credential or being a part of a group that is given access)
|
||||
is_public: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
access_type: Mapped[AccessType] = mapped_column(
|
||||
Enum(AccessType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
# special info needed for the auto-sync feature. The exact structure depends on the
|
||||
|
||||
# source type (defined in the connector's `source` field)
|
||||
# E.g. for google_drive perm sync:
|
||||
# {"customer_id": "123567", "company_domain": "@danswer.ai"}
|
||||
auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# Time finished, not used for calculating backend jobs which uses time started (created)
|
||||
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
@@ -413,6 +439,7 @@ class ConnectorCredentialPair(Base):
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Danswer)
|
||||
@@ -426,12 +453,27 @@ class Document(Base):
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
# connector retries)
|
||||
# TODO: rename this column because it conflates the time of the source doc
|
||||
# with the local last modified time of the doc and any associated metadata
|
||||
# it should just be the server timestamp of the source doc
|
||||
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# last time any vespa relevant row metadata or the doc changed.
|
||||
# does not include last_synced
|
||||
last_modified: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True, default=func.now()
|
||||
)
|
||||
|
||||
# last successful sync to vespa
|
||||
last_synced: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
# The following are not attached to User because the account/email may not be known
|
||||
# within Danswer
|
||||
# Something like the document creator
|
||||
@@ -441,14 +483,25 @@ class Document(Base):
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# TODO if more sensitive data is added here for display, make sure to add user/group permission
|
||||
# Permission sync columns
|
||||
# Email addresses are saved at the document level for externally synced permissions
|
||||
# This is becuase the normal flow of assigning permissions is through the cc_pair
|
||||
# doesn't apply here
|
||||
external_user_emails: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_ids: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
secondary="document__tag",
|
||||
secondary=Document__Tag.__table__,
|
||||
back_populates="documents",
|
||||
)
|
||||
|
||||
@@ -465,7 +518,7 @@ class Tag(Base):
|
||||
|
||||
documents = relationship(
|
||||
"Document",
|
||||
secondary="document__tag",
|
||||
secondary=Document__Tag.__table__,
|
||||
back_populates="tags",
|
||||
)
|
||||
|
||||
@@ -521,7 +574,9 @@ class Credential(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# if `true`, then all Admins will have access to the credential
|
||||
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -576,6 +631,8 @@ class SearchSettings(Base):
|
||||
Enum(RerankerProvider, native_enum=False), nullable=True
|
||||
)
|
||||
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
|
||||
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
@@ -607,6 +664,10 @@ class SearchSettings(Base):
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_url(self) -> str | None:
|
||||
return self.cloud_provider.api_url if self.cloud_provider is not None else None
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
@@ -671,7 +732,11 @@ class IndexAttempt(Base):
|
||||
"SearchSettings", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
error_rows = relationship("IndexAttemptError", back_populates="index_attempt")
|
||||
error_rows = relationship(
|
||||
"IndexAttemptError",
|
||||
back_populates="index_attempt",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
@@ -806,7 +871,7 @@ class SearchDoc(Base):
|
||||
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
secondary="chat_message__search_doc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="search_docs",
|
||||
)
|
||||
|
||||
@@ -835,8 +900,12 @@ class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -870,7 +939,6 @@ class ChatSession(Base):
|
||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||
PydanticType(PromptOverride), nullable=True
|
||||
)
|
||||
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -879,7 +947,6 @@ class ChatSession(Base):
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
@@ -949,7 +1016,7 @@ class ChatMessage(Base):
|
||||
)
|
||||
search_docs: Mapped[list["SearchDoc"]] = relationship(
|
||||
"SearchDoc",
|
||||
secondary="chat_message__search_doc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
@@ -972,7 +1039,9 @@ class ChatFolder(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Only null if auth is off
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
|
||||
|
||||
@@ -1085,6 +1154,7 @@ class CloudEmbeddingProvider(Base):
|
||||
provider_type: Mapped[EmbeddingProvider] = mapped_column(
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
search_settings: Mapped[list["SearchSettings"]] = relationship(
|
||||
"SearchSettings",
|
||||
@@ -1102,7 +1172,9 @@ class DocumentSet(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# Whether changes to the document set have been propagated
|
||||
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
# If `False`, then the document set is not visible to users who are not explicitly
|
||||
@@ -1146,7 +1218,9 @@ class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
system_prompt: Mapped[str] = mapped_column(Text)
|
||||
@@ -1181,9 +1255,13 @@ class Tool(Base):
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
# user who created / owns the tool. Will be None for built-in tools.
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
@@ -1207,7 +1285,9 @@ class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
@@ -1236,9 +1316,18 @@ class Persona(Base):
|
||||
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
# Default personas are configured via backend during deployment
|
||||
search_start_date: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
# Built-in personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -1289,10 +1378,10 @@ class Persona(Base):
|
||||
# Default personas loaded via yaml cannot have the same name
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"_default_persona_name_idx",
|
||||
"_builtin_persona_name_idx",
|
||||
"name",
|
||||
unique=True,
|
||||
postgresql_where=(default_persona == True), # noqa: E712
|
||||
postgresql_where=(builtin_persona == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1316,53 +1405,6 @@ class ChannelConfig(TypedDict):
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1388,7 +1430,7 @@ class SlackBotConfig(Base):
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
@@ -1400,7 +1442,7 @@ class TaskQueueState(Base):
|
||||
__tablename__ = "task_queue_jobs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Celery task id
|
||||
# Celery task id. currently only for readability/diagnostics
|
||||
task_id: Mapped[str] = mapped_column(String)
|
||||
# For any job type, this would be the same
|
||||
task_name: Mapped[str] = mapped_column(String)
|
||||
@@ -1450,7 +1492,9 @@ class SamlAccount(Base):
|
||||
__tablename__ = "saml"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True
|
||||
)
|
||||
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
|
||||
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -1469,7 +1513,7 @@ class User__UserGroup(Base):
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1618,94 +1662,70 @@ class TokenRateLimit__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
class PermissionSyncStatus(str, PyEnum):
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PermissionSyncJobType(str, PyEnum):
|
||||
USER_LEVEL = "user_level"
|
||||
GROUP_LEVEL = "group_level"
|
||||
|
||||
|
||||
class PermissionSyncRun(Base):
|
||||
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
|
||||
the users or it is sync-ing the groups"""
|
||||
|
||||
__tablename__ = "permission_sync_run"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# Not strictly needed but makes it easy to use without fetching from cc_pair
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
# Currently all sync jobs are handled as a group permission sync or a user permission sync
|
||||
update_type: Mapped[PermissionSyncJobType] = mapped_column(
|
||||
Enum(PermissionSyncJobType)
|
||||
)
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=True
|
||||
)
|
||||
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
|
||||
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
|
||||
|
||||
|
||||
class ExternalPermission(Base):
|
||||
class User__ExternalUserGroupId(Base):
|
||||
"""Maps user info both internal and external to the name of the external group
|
||||
This maps the user to all of their external groups so that the external group name can be
|
||||
attached to the ACL list matching during query time. User level permissions can be handled by
|
||||
directly adding the Danswer user to the doc ACL list"""
|
||||
|
||||
__tablename__ = "external_permission"
|
||||
__tablename__ = "user__external_user_group_id"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
external_permission_group: Mapped[str] = mapped_column(String)
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
class EmailToExternalUserCache(Base):
|
||||
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
|
||||
when the user joins. Used as a cache for when fetching external groups which have their own
|
||||
user ids, this can easily be mapped back to users already known in Danswer without needing
|
||||
to call external APIs to get the user emails.
|
||||
|
||||
This way when groups are updated in the external tool and we need to update the mapping of
|
||||
internal users to the groups, we can sync the internal users to the external groups they are
|
||||
part of using this.
|
||||
|
||||
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
|
||||
part of alpha in some external tool.
|
||||
"""
|
||||
|
||||
__tablename__ = "email_to_external_user_cache"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_user_id: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
|
||||
user = relationship("User")
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
@@ -1721,7 +1741,7 @@ class UsageReport(Base):
|
||||
|
||||
# if None, report was auto-generated
|
||||
requestor_user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from uuid import UUID
|
||||
|
||||
@@ -178,6 +179,7 @@ def create_update_persona(
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return PersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
@@ -210,6 +212,22 @@ def update_persona_shared_users(
|
||||
)
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
) -> None:
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise ValueError("You don't have permission to modify this persona")
|
||||
|
||||
persona.is_public = is_public
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
@@ -242,7 +260,7 @@ def get_personas(
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.default_persona.is_(False))
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
|
||||
if not include_deleted:
|
||||
@@ -290,7 +308,7 @@ def mark_delete_persona_by_name(
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(Persona)
|
||||
.where(Persona.name == persona_name, Persona.default_persona == is_default)
|
||||
.where(Persona.name == persona_name, Persona.builtin_persona == is_default)
|
||||
.values(deleted=True)
|
||||
)
|
||||
|
||||
@@ -390,7 +408,6 @@ def upsert_persona(
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
commit: bool = True,
|
||||
icon_color: str | None = None,
|
||||
icon_shape: int | None = None,
|
||||
@@ -398,6 +415,9 @@ def upsert_persona(
|
||||
display_priority: int | None = None,
|
||||
is_visible: bool = True,
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool = False,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
@@ -438,8 +458,8 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
if not builtin_persona and persona.builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
# this checks if the user has permission to edit the persona
|
||||
persona = fetch_persona_by_id(
|
||||
@@ -454,7 +474,7 @@ def upsert_persona(
|
||||
persona.llm_relevance_filter = llm_relevance_filter
|
||||
persona.llm_filter_extraction = llm_filter_extraction
|
||||
persona.recency_bias = recency_bias
|
||||
persona.default_persona = default_persona
|
||||
persona.builtin_persona = builtin_persona
|
||||
persona.llm_model_provider_override = llm_model_provider_override
|
||||
persona.llm_model_version_override = llm_model_version_override
|
||||
persona.starter_messages = starter_messages
|
||||
@@ -466,6 +486,8 @@ def upsert_persona(
|
||||
persona.uploaded_image_id = uploaded_image_id
|
||||
persona.display_priority = display_priority
|
||||
persona.is_visible = is_visible
|
||||
persona.search_start_date = search_start_date
|
||||
persona.is_default_persona = is_default_persona
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -493,7 +515,7 @@ def upsert_persona(
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
default_persona=default_persona,
|
||||
builtin_persona=builtin_persona,
|
||||
prompts=prompts or [],
|
||||
document_sets=document_sets or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -505,6 +527,8 @@ def upsert_persona(
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=is_default_persona,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
@@ -534,7 +558,7 @@ def delete_old_default_personas(
|
||||
Need a more graceful fix later or those need to never have IDs"""
|
||||
stmt = (
|
||||
update(Persona)
|
||||
.where(Persona.default_persona, Persona.id > 0)
|
||||
.where(Persona.builtin_persona, Persona.id > 0)
|
||||
.values(deleted=True, name=func.concat(Persona.name, "_old"))
|
||||
)
|
||||
|
||||
@@ -551,6 +575,7 @@ def update_persona_visibility(
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_visible = is_visible
|
||||
db_session.commit()
|
||||
|
||||
@@ -563,13 +588,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return prompts
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
@@ -650,9 +677,7 @@ def get_persona_by_id(
|
||||
result = db_session.execute(persona_stmt)
|
||||
persona = result.scalar_one_or_none()
|
||||
if persona is None:
|
||||
raise ValueError(
|
||||
f"Persona with ID {persona_id} does not exist or does not belong to user"
|
||||
)
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist")
|
||||
return persona
|
||||
|
||||
# or check if user owns persona
|
||||
@@ -715,7 +740,7 @@ def delete_persona_by_name(
|
||||
persona_name: str, db_session: Session, is_default: bool = True
|
||||
) -> None:
|
||||
stmt = delete(Persona).where(
|
||||
Persona.name == persona_name, Persona.default_persona == is_default
|
||||
Persona.name == persona_name, Persona.builtin_persona == is_default
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -13,10 +15,12 @@ from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
@@ -89,6 +93,30 @@ def get_current_db_embedding_provider(
|
||||
return current_embedding_provider
|
||||
|
||||
|
||||
def delete_search_settings(db_session: Session, search_settings_id: int) -> None:
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
if current_settings.id == search_settings_id:
|
||||
raise ValueError("Cannot delete currently active search settings")
|
||||
|
||||
# First, delete associated index attempts
|
||||
index_attempts_query = delete(IndexAttempt).where(
|
||||
IndexAttempt.search_settings_id == search_settings_id
|
||||
)
|
||||
db_session.execute(index_attempts_query)
|
||||
|
||||
# Then, delete the search settings
|
||||
search_settings_query = delete(SearchSettings).where(
|
||||
and_(
|
||||
SearchSettings.id == search_settings_id,
|
||||
SearchSettings.status != IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(search_settings_query)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_current_search_settings(db_session: Session) -> SearchSettings:
|
||||
query = (
|
||||
select(SearchSettings)
|
||||
@@ -115,6 +143,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
||||
return latest_settings
|
||||
|
||||
|
||||
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
query = select(SearchSettings).order_by(SearchSettings.id.desc())
|
||||
result = db_session.execute(query)
|
||||
all_settings = result.scalars().all()
|
||||
return list(all_settings)
|
||||
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
@@ -146,6 +181,14 @@ def update_current_search_settings(
|
||||
logger.warning("No current search settings found to update")
|
||||
return
|
||||
|
||||
# Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
|
||||
if (
|
||||
search_settings.rerank_provider_type is None
|
||||
and search_settings.rerank_model_name is not None
|
||||
and current_settings.rerank_model_name != search_settings.rerank_model_name
|
||||
):
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
update_search_settings(current_settings, search_settings, preserved_fields)
|
||||
db_session.commit()
|
||||
logger.info("Current search settings updated successfully")
|
||||
@@ -234,6 +277,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -246,4 +290,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -14,8 +15,11 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.errors import EERequiredError
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def _build_persona_name(channel_names: list[str]) -> str:
|
||||
@@ -62,7 +66,7 @@ def create_slack_bot_persona(
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
default_persona=False,
|
||||
is_default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
if len(existing_standard_answer_categories) == 0:
|
||||
raise EERequiredError(
|
||||
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
@@ -117,9 +140,18 @@ def update_slack_bot_config(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
|
||||
@@ -44,12 +44,11 @@ def get_latest_task_by_type(
|
||||
|
||||
|
||||
def register_task(
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> TaskQueueState:
|
||||
new_task = TaskQueueState(
|
||||
task_id=task_id, task_name=task_name, status=TaskStatus.PENDING
|
||||
task_id="", task_name=task_name, status=TaskStatus.PENDING
|
||||
)
|
||||
|
||||
db_session.add(new_task)
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Tool
|
||||
from danswer.server.features.tool.models import Header
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -25,6 +26,7 @@ def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
@@ -33,6 +35,9 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
@@ -45,6 +50,7 @@ def update_tool(
|
||||
name: str | None,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
@@ -60,6 +66,8 @@ def update_tool(
|
||||
tool.openapi_schema = openapi_schema
|
||||
if user_id is not None:
|
||||
tool.user_id = user_id
|
||||
if custom_headers is not None:
|
||||
tool.custom_headers = [header.dict() for header in custom_headers]
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
@@ -20,8 +23,23 @@ def list_users(
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_users_by_emails(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> tuple[list[User], list[str]]:
|
||||
# Use distinct to avoid duplicates
|
||||
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
found_users_emails = [user.email for user in found_users]
|
||||
missing_user_emails = [email for email in emails if email not in found_users_emails]
|
||||
return found_users, missing_user_emails
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = db_session.query(User).filter(User.email == email).first() # type: ignore
|
||||
user = (
|
||||
db_session.query(User)
|
||||
.filter(func.lower(User.email) == func.lower(email))
|
||||
.first()
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@@ -30,3 +48,52 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
|
||||
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def _generate_non_web_user(email: str) -> User:
|
||||
fastapi_users_pw_helper = PasswordHelper()
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
has_web_login=False,
|
||||
role=UserRole.BASIC,
|
||||
)
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.flush() # generate id
|
||||
return user
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
@@ -177,6 +177,30 @@ class Updatable(abc.ABC):
|
||||
- Whether the document is hidden or not, hidden documents are not returned from search
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
"""
|
||||
Updates some set of chunks for a document. The document and fields to update
|
||||
are specified in the update request. Each update request in the list applies
|
||||
its changes to a list of document ids.
|
||||
None values mean that the field does not need an update.
|
||||
|
||||
The rationale for a single update function is that it allows retries and parallelism
|
||||
to happen at a higher / more strategic level, is simpler to read, and allows
|
||||
us to individually handle error conditions per document.
|
||||
|
||||
Parameters:
|
||||
- update_request: for a list of document ids in the update request, apply the same updates
|
||||
to all of the documents with those ids.
|
||||
|
||||
Return:
|
||||
- an HTTPStatus code. The code can used to decide whether to fail immediately,
|
||||
retry, etc. Although this method likely hits an HTTP API behind the
|
||||
scenes, the usage of HTTPStatus is a convenience and the interface is not
|
||||
actually HTTP specific.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[UpdateRequest]) -> None:
|
||||
"""
|
||||
|
||||
@@ -26,6 +26,17 @@
|
||||
<disk>0.75</disk>
|
||||
</resource-limits>
|
||||
</tuning>
|
||||
<engine>
|
||||
<proton>
|
||||
<tuning>
|
||||
<searchnode>
|
||||
<requestthreads>
|
||||
<persearch>SEARCH_THREAD_NUMBER</persearch>
|
||||
</requestthreads>
|
||||
</searchnode>
|
||||
</tuning>
|
||||
</proton>
|
||||
</engine>
|
||||
<config name="vespa.config.search.summary.juniperrc">
|
||||
<max_matches>3</max_matches>
|
||||
<length>750</length>
|
||||
@@ -33,4 +44,4 @@
|
||||
<min_length>300</min_length>
|
||||
</config>
|
||||
</content>
|
||||
</services>
|
||||
</services>
|
||||
@@ -30,6 +30,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from danswer.document_index.vespa_constants import HIDDEN
|
||||
from danswer.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
|
||||
from danswer.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
|
||||
from danswer.document_index.vespa_constants import MAX_OR_CONDITIONS
|
||||
from danswer.document_index.vespa_constants import METADATA
|
||||
from danswer.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from danswer.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
@@ -292,12 +293,11 @@ def query_vespa(
|
||||
if LOG_VESPA_TIMING_INFORMATION
|
||||
else {},
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
SEARCH_ENDPOINT,
|
||||
json=params,
|
||||
)
|
||||
try:
|
||||
response = requests.post(
|
||||
SEARCH_ENDPOINT,
|
||||
json=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
|
||||
@@ -319,6 +319,12 @@ def query_vespa(
|
||||
logger.debug("Vespa timing info: %s", response_json.get("timing"))
|
||||
hits = response_json["root"].get("children", [])
|
||||
|
||||
if not hits:
|
||||
logger.warning(
|
||||
f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}"
|
||||
)
|
||||
logger.debug(f"Vespa Response: {response.text}")
|
||||
|
||||
for hit in hits:
|
||||
if hit["fields"].get(CONTENT) is None:
|
||||
identifier = hit["fields"].get("documentid") or hit["id"]
|
||||
@@ -379,7 +385,7 @@ def batch_search_api_retrieval(
|
||||
capped_requests: list[VespaChunkRequest] = []
|
||||
uncapped_requests: list[VespaChunkRequest] = []
|
||||
chunk_count = 0
|
||||
for request in chunk_requests:
|
||||
for req_ind, request in enumerate(chunk_requests, start=1):
|
||||
# All requests without a chunk range are uncapped
|
||||
# Uncapped requests are retrieved using the Visit API
|
||||
range = request.range
|
||||
@@ -387,9 +393,10 @@ def batch_search_api_retrieval(
|
||||
uncapped_requests.append(request)
|
||||
continue
|
||||
|
||||
# If adding the range to the chunk count is greater than the
|
||||
# max query size, we need to perform a retrieval to avoid hitting the limit
|
||||
if chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE:
|
||||
if (
|
||||
chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE
|
||||
or req_ind % MAX_OR_CONDITIONS == 0
|
||||
):
|
||||
retrieved_chunks.extend(
|
||||
_get_chunks_via_batch_search(
|
||||
index_name=index_name,
|
||||
|
||||
@@ -16,6 +16,7 @@ import requests
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import DocumentInsertionRecord
|
||||
@@ -52,6 +53,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from danswer.document_index.vespa_constants import HIDDEN
|
||||
from danswer.document_index.vespa_constants import NUM_THREADS
|
||||
from danswer.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
|
||||
from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
|
||||
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
|
||||
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
@@ -118,7 +120,7 @@ class VespaIndex(DocumentIndex):
|
||||
secondary_index_embedding_dim: int | None,
|
||||
) -> None:
|
||||
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
|
||||
logger.debug(f"Sending Vespa zip to {deploy_url}")
|
||||
logger.info(f"Deploying Vespa application package to {deploy_url}")
|
||||
|
||||
vespa_schema_path = os.path.join(
|
||||
os.getcwd(), "danswer", "document_index", "vespa", "app_config"
|
||||
@@ -134,6 +136,10 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
doc_lines = _create_document_xml_lines(schema_names)
|
||||
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
|
||||
services = services.replace(
|
||||
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
|
||||
)
|
||||
|
||||
kv_store = get_dynamic_config_store()
|
||||
|
||||
needs_reindexing = False
|
||||
@@ -282,7 +288,7 @@ class VespaIndex(DocumentIndex):
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest]) -> None:
|
||||
logger.info(f"Updating {len(update_requests)} documents in Vespa")
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_requests but it's not used later anyway
|
||||
@@ -371,6 +377,91 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
"""
|
||||
if len(update_request.document_ids) != 1:
|
||||
raise ValueError("update_request must contain a single document id")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_request but it's not used later anyway
|
||||
update_request.document_ids = [
|
||||
replace_invalid_doc_id_characters(doc_id)
|
||||
for doc_id in update_request.document_ids
|
||||
]
|
||||
|
||||
# update_start = time.monotonic()
|
||||
|
||||
# Fetch all chunks for each document ahead of time
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
all_doc_chunk_ids: list[str] = []
|
||||
for index_name in index_names:
|
||||
for document_id in update_request.document_ids:
|
||||
# this calls vespa and can raise http exceptions
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
filters=None,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
all_doc_chunk_ids.extend(doc_chunk_ids)
|
||||
logger.debug(
|
||||
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
|
||||
)
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
if update_request.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": update_request.boost}
|
||||
if update_request.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {
|
||||
document_set: 1 for document_set in update_request.document_sets
|
||||
}
|
||||
}
|
||||
if update_request.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
|
||||
}
|
||||
if update_request.hidden is not None:
|
||||
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
|
||||
|
||||
if not update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
return
|
||||
|
||||
processed_update_requests: list[_VespaUpdateRequest] = []
|
||||
for document_id in update_request.document_ids:
|
||||
for doc_chunk_id in all_doc_chunk_ids:
|
||||
processed_update_requests.append(
|
||||
_VespaUpdateRequest(
|
||||
document_id=document_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with httpx.Client(http2=True) as http_client:
|
||||
for update in processed_update_requests:
|
||||
http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
# logger.debug(
|
||||
# "Finished updating Vespa documents in %.2f seconds",
|
||||
# time.monotonic() - update_start,
|
||||
# )
|
||||
|
||||
return
|
||||
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
|
||||
|
||||
|
||||
@@ -162,14 +162,16 @@ def _index_vespa_chunk(
|
||||
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
BOOST: chunk.boost,
|
||||
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
|
||||
PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners),
|
||||
SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners),
|
||||
# the only `set` vespa has is `weightedset`, so we have to give each
|
||||
# element an arbitrary weight
|
||||
# rkuo: acl, docset and boost metadata are also updated through the metadata sync queue
|
||||
# which only calls VespaIndex.update
|
||||
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
BOOST: chunk.boost,
|
||||
}
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
|
||||
|
||||
@@ -7,6 +7,7 @@ from danswer.configs.constants import SOURCE_TYPE
|
||||
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
|
||||
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
|
||||
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
|
||||
DATE_REPLACEMENT = "DATE_REPLACEMENT"
|
||||
|
||||
# config server
|
||||
@@ -25,6 +26,9 @@ NUM_THREADS = (
|
||||
32 # since Vespa doesn't allow batching of inserts / updates, we use threads
|
||||
)
|
||||
MAX_ID_SEARCH_QUERY_SIZE = 400
|
||||
# Suspect that adding too many "or" conditions will cause Vespa to timeout and return
|
||||
# an empty list of hits (with no error status and coverage: 0 and degraded)
|
||||
MAX_OR_CONDITIONS = 10
|
||||
# up from 500ms for now, since we've seen quite a few timeouts
|
||||
# in the long term, we are looking to improve the performance of Vespa
|
||||
# so that we can bring this back to default
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user