mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
283 Commits
concurrent
...
final_grap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0d3eb28e8 | ||
|
|
d789c9ac52 | ||
|
|
d989ce13e7 | ||
|
|
3170430673 | ||
|
|
2beffdaa6e | ||
|
|
77ee061e67 | ||
|
|
532bc53a9a | ||
|
|
7b7b95703d | ||
|
|
fcc5efdaf8 | ||
|
|
1ea4a53af1 | ||
|
|
47479c8799 | ||
|
|
fbc5008259 | ||
|
|
d684fb116d | ||
|
|
2e61b374f4 | ||
|
|
15d324834f | ||
|
|
de9a9b7b6e | ||
|
|
47eb8c521d | ||
|
|
875fb05dca | ||
|
|
1285b2f4d4 | ||
|
|
842628771b | ||
|
|
7a9d5bd92e | ||
|
|
4f3b513ccb | ||
|
|
cd454dd780 | ||
|
|
9140ee99cb | ||
|
|
a64f27c895 | ||
|
|
fdf5611a35 | ||
|
|
c4f483d100 | ||
|
|
fc28c6b9e1 | ||
|
|
33e25dbd8b | ||
|
|
659e8cb69e | ||
|
|
681175e9c3 | ||
|
|
de18ec7ea4 | ||
|
|
9edbb0806d | ||
|
|
63d10e7482 | ||
|
|
ff6a15b5af | ||
|
|
49397e8a86 | ||
|
|
285bdbbaf9 | ||
|
|
e2c37d6847 | ||
|
|
3ff2ba7ee4 | ||
|
|
290f4f0f8c | ||
|
|
3c934a93cd | ||
|
|
a51b0f636e | ||
|
|
a50c2e30ec | ||
|
|
ee278522ef | ||
|
|
430c9a47d7 | ||
|
|
974f85da66 | ||
|
|
a63cb9da43 | ||
|
|
d807ad7699 | ||
|
|
3cb00de6d4 | ||
|
|
da6e46ae75 | ||
|
|
648c2531f9 | ||
|
|
fc98c560a4 | ||
|
|
566f44fcd6 | ||
|
|
2fe49e5efb | ||
|
|
f58acd4e2a | ||
|
|
53008a0271 | ||
|
|
13278663d9 | ||
|
|
31ca6857fb | ||
|
|
6dd91414be | ||
|
|
140c34e59e | ||
|
|
da8e68b320 | ||
|
|
e9a616e579 | ||
|
|
cb2169f2a3 | ||
|
|
79aa5dd6e0 | ||
|
|
604ebafe6c | ||
|
|
a2d775efbd | ||
|
|
641690e3f7 | ||
|
|
eebf98e3a6 | ||
|
|
4bc4da29f5 | ||
|
|
7af572d0e7 | ||
|
|
58bdf9d684 | ||
|
|
f69922fff7 | ||
|
|
d4d37c9cdd | ||
|
|
2654df49fd | ||
|
|
aee5fcd4e0 | ||
|
|
2c77dd241b | ||
|
|
d90c90dd92 | ||
|
|
2c971cf774 | ||
|
|
eab55bdd85 | ||
|
|
f4f2fb5943 | ||
|
|
71f2f1a90a | ||
|
|
74a2271422 | ||
|
|
d42fb6ce34 | ||
|
|
0d749ebd46 | ||
|
|
9f6e8bd124 | ||
|
|
3a2a6abed4 | ||
|
|
07f49a384f | ||
|
|
f1c5e80f17 | ||
|
|
b7ad810d83 | ||
|
|
99b28643f7 | ||
|
|
f52d1142eb | ||
|
|
e563746730 | ||
|
|
aa86830bde | ||
|
|
4558351801 | ||
|
|
a4dcae57cd | ||
|
|
dbd56f946f | ||
|
|
e4e4765c60 | ||
|
|
c967f53c02 | ||
|
|
3a9b964d5c | ||
|
|
f04ecbf87a | ||
|
|
362156f97e | ||
|
|
3fa9676478 | ||
|
|
be4b6189d2 | ||
|
|
ace041415a | ||
|
|
148c2a7375 | ||
|
|
1555ac9dab | ||
|
|
80de408cef | ||
|
|
e20c825e16 | ||
|
|
b0568ac8ae | ||
|
|
0896d3b7da | ||
|
|
87b27046bd | ||
|
|
5e9c6d1499 | ||
|
|
50211ec401 | ||
|
|
6012a7cbd9 | ||
|
|
1e4b27185d | ||
|
|
0c66da17bb | ||
|
|
d985cd4352 | ||
|
|
c8891a5829 | ||
|
|
51a13f5fc7 | ||
|
|
57c1deb8b8 | ||
|
|
e2e04af7e2 | ||
|
|
c1735fcd3a | ||
|
|
b43e5735d7 | ||
|
|
7d4f8ef4e8 | ||
|
|
7c03b6f521 | ||
|
|
ccf986808c | ||
|
|
350482e53e | ||
|
|
fb3d7330fa | ||
|
|
6cec31088d | ||
|
|
491f3254a5 | ||
|
|
5abf67fbf0 | ||
|
|
2933c3598b | ||
|
|
aeb6060854 | ||
|
|
8977b1b5fc | ||
|
|
69c0419146 | ||
|
|
2bd3833c55 | ||
|
|
2d7b312e6c | ||
|
|
ebe3674ca7 | ||
|
|
04f83eb1e1 | ||
|
|
420aabc963 | ||
|
|
61a17319c9 | ||
|
|
e4c85352b4 | ||
|
|
34ba3181ff | ||
|
|
630e2248bd | ||
|
|
c358c91e4c | ||
|
|
2b7915f33b | ||
|
|
0ff1a023cd | ||
|
|
d68d281e1c | ||
|
|
ebce3ff6ba | ||
|
|
f96bd12ab8 | ||
|
|
32359d2dff | ||
|
|
5da6d792de | ||
|
|
fb95398e5b | ||
|
|
af66650ee3 | ||
|
|
5b1f3c8d4e | ||
|
|
a3b1b1db38 | ||
|
|
7520fae068 | ||
|
|
39c946536c | ||
|
|
90528ba195 | ||
|
|
6afcaafe54 | ||
|
|
812ca69949 | ||
|
|
abe01144ca | ||
|
|
d988a3e736 | ||
|
|
2b14afe878 | ||
|
|
033ec0b6b1 | ||
|
|
14a9fecc64 | ||
|
|
0027f161d7 | ||
|
|
32e551b69c | ||
|
|
299cb5035c | ||
|
|
910821c723 | ||
|
|
aa84846298 | ||
|
|
c122be2f6a | ||
|
|
f871b4c6eb | ||
|
|
a96cea2ce0 | ||
|
|
8d443ada5b | ||
|
|
634de83d72 | ||
|
|
580848cf8c | ||
|
|
f01027cfb7 | ||
|
|
76db4b765a | ||
|
|
5800c7158e | ||
|
|
21af852073 | ||
|
|
355326f935 | ||
|
|
762b7b1047 | ||
|
|
df31cac1f1 | ||
|
|
4181124e7a | ||
|
|
44c45cbf2a | ||
|
|
f2e8680955 | ||
|
|
b952dbef42 | ||
|
|
e2f4145cd2 | ||
|
|
183569061b | ||
|
|
8f26728a29 | ||
|
|
1734a4a18c | ||
|
|
766652de14 | ||
|
|
00fa36d591 | ||
|
|
3b596fd6a8 | ||
|
|
5a83b00190 | ||
|
|
57491ceaae | ||
|
|
e4e67c61ef | ||
|
|
8afa53c6bf | ||
|
|
fb6637d5b3 | ||
|
|
1e67332078 | ||
|
|
effce919bd | ||
|
|
e5b3843ef8 | ||
|
|
50c17438d5 | ||
|
|
657d2050a5 | ||
|
|
3640d0c550 | ||
|
|
336ddbd1fe | ||
|
|
8614cd8934 | ||
|
|
525f3e01f5 | ||
|
|
feaa85f764 | ||
|
|
b36cd4937f | ||
|
|
97ba71e1b3 | ||
|
|
5f12b7ad58 | ||
|
|
a873fc6483 | ||
|
|
c0e1a02e8e | ||
|
|
205c3c3fc8 | ||
|
|
e5ceb76de8 | ||
|
|
c21b0ee3f5 | ||
|
|
1e1b2a0901 | ||
|
|
c1c35b00cb | ||
|
|
1bc899cc67 | ||
|
|
6fc6ee5c37 | ||
|
|
7d201f67d4 | ||
|
|
e749fa0f28 | ||
|
|
2e0222d1c1 | ||
|
|
c152123ef4 | ||
|
|
5cb9c17ddf | ||
|
|
b1302303b2 | ||
|
|
e89dc67e5d | ||
|
|
7da6d33451 | ||
|
|
c042a19c00 | ||
|
|
5409777e0b | ||
|
|
5f4b7dd23e | ||
|
|
99db27d989 | ||
|
|
197b62aed1 | ||
|
|
9d5db05e4b | ||
|
|
27e094d2ec | ||
|
|
1a9e5da7c0 | ||
|
|
8afcb03f3c | ||
|
|
9bf42d2303 | ||
|
|
e50b558b5b | ||
|
|
020dff52f7 | ||
|
|
13303edf29 | ||
|
|
584eae17e3 | ||
|
|
b9b633bb74 | ||
|
|
bb1916d5d0 | ||
|
|
048cb8dd55 | ||
|
|
3b035d791e | ||
|
|
53387ab3eb | ||
|
|
ec6e2369a1 | ||
|
|
075eacdd91 | ||
|
|
f77b1ebd87 | ||
|
|
1ddb4b2025 | ||
|
|
42f0fea9f8 | ||
|
|
8de04acb7f | ||
|
|
5053f4e383 | ||
|
|
730a757090 | ||
|
|
006cfa1d3d | ||
|
|
69f6b7d148 | ||
|
|
53a3fb8e52 | ||
|
|
919110a655 | ||
|
|
19cccd267d | ||
|
|
71c2b16a01 | ||
|
|
12f0dbcfc5 | ||
|
|
583bd1d207 | ||
|
|
8a4e47781b | ||
|
|
af647959f6 | ||
|
|
ea53977617 | ||
|
|
c44c22a009 | ||
|
|
5ab4d94d94 | ||
|
|
119aefba88 | ||
|
|
12fccfeffd | ||
|
|
8a7bc4e411 | ||
|
|
492797c9f3 | ||
|
|
739058aacc | ||
|
|
17570038bb | ||
|
|
c0edfb50df | ||
|
|
22573aba2a | ||
|
|
efae24acd0 | ||
|
|
f8e0e6f015 | ||
|
|
3cbc341b60 | ||
|
|
46c7089328 | ||
|
|
3ffbe659e3 |
76
.github/actions/custom-build-and-push/action.yml
vendored
Normal file
76
.github/actions/custom-build-and-push/action.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
name: 'Build and Push Docker Image with Retry'
|
||||
description: 'Attempts to build and push a Docker image, with a retry on failure'
|
||||
inputs:
|
||||
context:
|
||||
description: 'Build context'
|
||||
required: true
|
||||
file:
|
||||
description: 'Dockerfile location'
|
||||
required: true
|
||||
platforms:
|
||||
description: 'Target platforms'
|
||||
required: true
|
||||
pull:
|
||||
description: 'Always attempt to pull a newer version of the image'
|
||||
required: false
|
||||
default: 'true'
|
||||
push:
|
||||
description: 'Push the image to registry'
|
||||
required: false
|
||||
default: 'true'
|
||||
load:
|
||||
description: 'Load the image into Docker daemon'
|
||||
required: false
|
||||
default: 'true'
|
||||
tags:
|
||||
description: 'Image tags'
|
||||
required: true
|
||||
cache-from:
|
||||
description: 'Cache sources'
|
||||
required: false
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before retry in seconds'
|
||||
required: false
|
||||
default: '5'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build and push Docker image (First Attempt)
|
||||
id: buildx1
|
||||
uses: docker/build-push-action@v5
|
||||
continue-on-error: true
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
|
||||
- name: Wait to retry
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
run: |
|
||||
echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..."
|
||||
sleep ${{ inputs.retry-wait-time }}
|
||||
shell: bash
|
||||
|
||||
- name: Build and push Docker image (Retry Attempt)
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ inputs.context }}
|
||||
file: ${{ inputs.file }}
|
||||
platforms: ${{ inputs.platforms }}
|
||||
pull: ${{ inputs.pull }}
|
||||
push: ${{ inputs.push }}
|
||||
load: ${{ inputs.load }}
|
||||
tags: ${{ inputs.tags }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
@@ -1,33 +0,0 @@
|
||||
name: Build Backend Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
jobs:
|
||||
build:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Backend Image Docker Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: false
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
@@ -1,53 +0,0 @@
|
||||
name: Build Web Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: false
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
67
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
67
.github/workflows/pr-helm-chart-testing.yml.disabled.txt
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
lint-test:
|
||||
runs-on: Amd64
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4.2.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
1
.github/workflows/pr-python-checks.yml
vendored
1
.github/workflows/pr-python-checks.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Python Checks
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
|
||||
57
.github/workflows/pr-python-connector-tests.yml
vendored
Normal file
57
.github/workflows/pr-python-connector-tests.yml
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
4
.github/workflows/pr-python-tests.yml
vendored
4
.github/workflows/pr-python-tests.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Python Unit Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
@@ -10,7 +11,8 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
19
.github/workflows/pr-quality-checks.yml
vendored
19
.github/workflows/pr-quality-checks.yml
vendored
@@ -4,18 +4,19 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request: null
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
|
||||
162
.github/workflows/run-it.yml
vendored
Normal file
162
.github/workflows/run-it.yml
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
name: Run Integration Tests
|
||||
concurrency:
|
||||
group: Run-Integration-Tests-${{ github.head_ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: we don't need to build the Web Docker image since it's not used
|
||||
# during the IT for now. We have a separate action to verify it builds
|
||||
# succesfully
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull danswer/danswer-web-server:latest
|
||||
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-backend:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/danswer-model-server:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
type=registry,ref=danswer/integration-test-runner:it,mode=max
|
||||
type=inline
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=it \
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network danswer-stack_default \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
danswer/integration-test-runner:it
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: docker-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,6 +4,6 @@
|
||||
.mypy_cache
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/launch.json
|
||||
.vscode/
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
15
.vscode/env_template.txt
vendored
15
.vscode/env_template.txt
vendored
@@ -1,5 +1,5 @@
|
||||
# Copy this file to .env at the base of the repo and fill in the <REPLACE THIS> values
|
||||
# This will help with development iteration speed and reduce repeat tasks for dev
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
@@ -15,7 +15,7 @@ LOG_LEVEL=debug
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_DOC_RELEVANCE=True
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
@@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
@@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
|
||||
|
||||
# Python stuff
|
||||
PYTHONPATH=./backend
|
||||
PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
@@ -49,4 +49,3 @@ BING_API_KEY=<REPLACE THIS>
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
84
.vscode/launch.template.jsonc
vendored
84
.vscode/launch.template.jsonc
vendored
@@ -1,15 +1,23 @@
|
||||
/*
|
||||
|
||||
Copy this file into '.vscode/launch.json' or merge its
|
||||
contents into your existing configurations.
|
||||
|
||||
*/
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Run All Danswer Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
"Slack Bot"
|
||||
]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Web Server",
|
||||
@@ -17,7 +25,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -25,11 +33,12 @@
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"type": "python",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -39,16 +48,16 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"consoleTitle": "Model Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"type": "python",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -59,32 +68,32 @@
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"consoleTitle": "API Server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Indexing",
|
||||
"type": "python",
|
||||
"consoleName": "Indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"consoleTitle": "Indexing"
|
||||
}
|
||||
},
|
||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
||||
{
|
||||
"name": "Background Jobs",
|
||||
"type": "python",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -93,18 +102,18 @@
|
||||
},
|
||||
"args": [
|
||||
"--no-indexing"
|
||||
],
|
||||
"consoleTitle": "Background Jobs"
|
||||
]
|
||||
},
|
||||
// For the listner to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"type": "python",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "danswer/danswerbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -113,11 +122,12 @@
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"type": "python",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
@@ -128,18 +138,16 @@
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
},
|
||||
{
|
||||
"name": "Run Danswer",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Indexing",
|
||||
"Background Jobs",
|
||||
]
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
105
CONTRIBUTING.md
105
CONTRIBUTING.md
@@ -48,23 +48,26 @@ We would love to see you there!
|
||||
|
||||
|
||||
## Get Started 🚀
|
||||
Danswer being a fully functional app, relies on some external pieces of software, specifically:
|
||||
Danswer being a fully functional app, relies on some external software, specifically:
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for
|
||||
development purposes but also feel free to just use the containers and update with local changes by providing the
|
||||
`--build` flag.
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below.
|
||||
|
||||
|
||||
### Local Set Up
|
||||
It is recommended to use Python version 3.11
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, the version of Tensorflow we use may not be available for your platform.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
|
||||
#### Installing Requirements
|
||||
#### Backend: Python requirements
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
|
||||
For convenience here's a command for it:
|
||||
@@ -73,8 +76,9 @@ python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
|
||||
directory
|
||||
> **Note:**
|
||||
> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs.
|
||||
> For simplicity, we recommend setting up the virtual environment outside of the danswer directory.
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
```bash
|
||||
@@ -89,34 +93,38 @@ Install the required python dependencies:
|
||||
```bash
|
||||
pip install -r danswer/backend/requirements/default.txt
|
||||
pip install -r danswer/backend/requirements/dev.txt
|
||||
pip install -r danswer/backend/requirements/ee.txt
|
||||
pip install -r danswer/backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `danswer/web` run:
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
Install Playwright (required by the Web Connector)
|
||||
#### Docker containers for external software
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
|
||||
```bash
|
||||
playwright install
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache
|
||||
```
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
|
||||
#### Dependent Docker Containers
|
||||
First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with:
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db
|
||||
```
|
||||
(index refers to Vespa and relational_db refers to Postgres)
|
||||
|
||||
#### Running Danswer
|
||||
#### Running Danswer locally
|
||||
To start the frontend, navigate to `danswer/web` and run:
|
||||
```bash
|
||||
npm run dev
|
||||
@@ -127,11 +135,10 @@ Navigate to `danswer/backend` and run:
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
"
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Danswer, you will need to run the DB migrations for Postgres.
|
||||
@@ -154,6 +161,7 @@ To run the backend API server, navigate back to `danswer/backend` and run:
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
```bash
|
||||
powershell -Command "
|
||||
@@ -162,20 +170,58 @@ powershell -Command "
|
||||
"
|
||||
```
|
||||
|
||||
Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer.
|
||||
|
||||
You've successfully set up a local Danswer instance! 🏁
|
||||
|
||||
#### Running the Danswer application in a container
|
||||
|
||||
You can run the full Danswer application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `danswer/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer.
|
||||
|
||||
If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
#### Backend
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `danswer/backend` directory, run:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Danswer is fully type-annotated, and we would like to keep it that way!
|
||||
Danswer is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory.
|
||||
|
||||
|
||||
@@ -186,6 +232,7 @@ Please double check that prettier passes before creating a pull request.
|
||||
|
||||
|
||||
### Release Process
|
||||
Danswer follows the semver versioning standard.
|
||||
Danswer loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=danswer%2F).
|
||||
|
||||
31
CONTRIBUTING_MACOS.md
Normal file
31
CONTRIBUTING_MACOS.md
Normal file
@@ -0,0 +1,31 @@
|
||||
## Some additional notes for Mac Users
|
||||
The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Setting up Python
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up.
|
||||
|
||||
Then install python 3.11.
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add python 3.11 to your path: add the following line to ~/.zshrc
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
|
||||
### Setting up Docker
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and
|
||||
ensure it is running before continuing with the docker commands.
|
||||
|
||||
|
||||
### Formatting and Linting
|
||||
MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly.
|
||||
After installing pre-commit, run the following command:
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
@@ -9,7 +9,8 @@ founders@danswer.ai for more information. Please visit https://github.com/danswe
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
@@ -75,8 +76,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('wordnet', quiet=True); \
|
||||
nltk.download('punkt', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
@@ -8,7 +8,10 @@ visit https://github.com/danswer-ai/danswer."
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
@@ -22,14 +25,18 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
# Download model weights
|
||||
# Run Nomic to pull in the custom architecture and have it cached locally
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased', cache_folder='/root/.cache/temp_huggingface/hub/'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Danswer, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
BIN
backend/aaa garp.png
Normal file
BIN
backend/aaa garp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
@@ -8,6 +8,7 @@ from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -15,7 +16,9 @@ config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
@@ -29,6 +32,20 @@ target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -55,7 +72,11 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add ccpair deletion failure message
|
||||
|
||||
Revision ID: 0ebb1d516877
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-10 15:03:48.233926
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0ebb1d516877"
|
||||
down_revision = "52a219fb5233"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("deletion_failure_message", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "deletion_failure_message")
|
||||
@@ -0,0 +1,135 @@
|
||||
"""embedding model -> search settings
|
||||
|
||||
Revision ID: 1f60f60c3401
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-25 12:39:51.731632
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
down_revision = "f17bf3b0d9f1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
# Rename the table
|
||||
op.rename_table("embedding_model", "search_settings")
|
||||
|
||||
# Add new columns
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multipass_indexing", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"multilingual_expansion",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"disable_rerank_for_streaming",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"num_rerank",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=str(NUM_POSTPROCESSED_RESULTS),
|
||||
),
|
||||
)
|
||||
|
||||
# Add the new column as nullable initially
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the new column with data from the existing embedding_model_id
|
||||
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
|
||||
|
||||
# Create the foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_search_settings",
|
||||
"index_attempt",
|
||||
"search_settings",
|
||||
["search_settings_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Make the new column non-nullable
|
||||
op.alter_column("index_attempt", "search_settings_id", nullable=False)
|
||||
|
||||
# Drop the old embedding_model_id column
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the embedding_model_id column
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old column with data from search_settings_id
|
||||
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
|
||||
|
||||
# Make the old column non-nullable
|
||||
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the new search_settings_id column
|
||||
op.drop_column("index_attempt", "search_settings_id")
|
||||
|
||||
# Rename the table back
|
||||
op.rename_table("search_settings", "embedding_model")
|
||||
|
||||
# Remove added columns
|
||||
op.drop_column("embedding_model", "num_rerank")
|
||||
op.drop_column("embedding_model", "rerank_api_key")
|
||||
op.drop_column("embedding_model", "rerank_provider_type")
|
||||
op.drop_column("embedding_model", "rerank_model_name")
|
||||
op.drop_column("embedding_model", "disable_rerank_for_streaming")
|
||||
op.drop_column("embedding_model", "multilingual_expansion")
|
||||
op.drop_column("embedding_model", "multipass_indexing")
|
||||
|
||||
op.create_foreign_key(
|
||||
"index_attempt__embedding_model_fk",
|
||||
"index_attempt",
|
||||
"embedding_model",
|
||||
["embedding_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Add Above Below to Persona
|
||||
|
||||
Revision ID: 2d2304e27d8c
|
||||
Revises: 4b08d97e175a
|
||||
Create Date: 2024-08-21 19:15:15.762948
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2d2304e27d8c"
|
||||
down_revision = "4b08d97e175a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
|
||||
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
|
||||
|
||||
op.execute(
|
||||
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
|
||||
)
|
||||
|
||||
op.alter_column("persona", "chunks_above", nullable=False)
|
||||
op.alter_column("persona", "chunks_below", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
90
backend/alembic/versions/351faebd379d_add_curator_fields.py
Normal file
90
backend/alembic/versions/351faebd379d_add_curator_fields.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Add curator fields
|
||||
|
||||
Revision ID: 351faebd379d
|
||||
Revises: ee3f4b47fad5
|
||||
Create Date: 2024-08-15 22:37:08.397052
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "351faebd379d"
|
||||
down_revision = "ee3f4b47fad5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add is_curator column to User__UserGroup table
|
||||
op.add_column(
|
||||
"user__user_group",
|
||||
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Create the association table
|
||||
op.create_table(
|
||||
"credential__user_group",
|
||||
sa.Column("credential_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["credential_id"],
|
||||
["credential.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"],
|
||||
["user_group.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
|
||||
)
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column(
|
||||
"curator_public", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Update existing records to ensure they fit within the BASIC/ADMIN roles
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
|
||||
)
|
||||
|
||||
# Remove is_curator column from User__UserGroup table
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
),
|
||||
existing_type=sa.Enum(
|
||||
"BASIC",
|
||||
"ADMIN",
|
||||
"CURATOR",
|
||||
"GLOBAL_CURATOR",
|
||||
name="userrole",
|
||||
native_enum=False,
|
||||
),
|
||||
existing_nullable=False,
|
||||
)
|
||||
# Drop the association table
|
||||
op.drop_table("credential__user_group")
|
||||
op.drop_column("credential", "curator_public")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""change default prune_freq
|
||||
|
||||
Revision ID: 4b08d97e175a
|
||||
Revises: d9ec13955951
|
||||
Create Date: 2024-08-20 15:28:52.993827
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4b08d97e175a"
|
||||
down_revision = "d9ec13955951"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 2592000
|
||||
WHERE prune_freq = 86400
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET prune_freq = 86400
|
||||
WHERE prune_freq = 2592000
|
||||
"""
|
||||
)
|
||||
@@ -1,17 +1,18 @@
|
||||
"""migrate tool calls
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: eb690a089310
|
||||
Revises: ee3f4b47fad5
|
||||
Create Date: 2024-08-04 17:07:47.533051
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "eb690a089310"
|
||||
down_revision = "ee3f4b47fad5"
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -38,6 +39,11 @@ def upgrade() -> None:
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "52a219fb5233"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last modified represents the last time anything needing syncing to vespa changed
|
||||
# including row metadata and the document itself. This obviously does not include
|
||||
# the last_synced column.
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# last synced represents the last time this document was synced to Vespa
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Set last_synced to the same value as last_modified for existing rows
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET last_synced = last_modified
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_modified"),
|
||||
"document",
|
||||
["last_modified"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_synced"),
|
||||
"document",
|
||||
["last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
|
||||
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
|
||||
op.drop_column("document", "last_synced")
|
||||
op.drop_column("document", "last_modified")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""match_any_keywords flag for standard answers
|
||||
|
||||
Revision ID: 5c7fdadae813
|
||||
Revises: efb35676026c
|
||||
Create Date: 2024-09-13 18:52:59.256478
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c7fdadae813"
|
||||
down_revision = "efb35676026c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column(
|
||||
"match_any_keywords",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("standard_answer", "match_any_keywords")
|
||||
# ### end Alembic commands ###
|
||||
@@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.enums import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -35,18 +35,22 @@ def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_credential_pair_id =
|
||||
CASE
|
||||
WHEN ia.credential_id IS NULL THEN
|
||||
(SELECT id FROM connector_credential_pair
|
||||
WHERE connector_id = ia.connector_id
|
||||
LIMIT 1)
|
||||
ELSE
|
||||
(SELECT id FROM connector_credential_pair
|
||||
WHERE connector_id = ia.connector_id
|
||||
AND credential_id = ia.credential_id)
|
||||
END
|
||||
WHERE ia.connector_id IS NOT NULL
|
||||
SET connector_credential_pair_id = (
|
||||
SELECT id FROM connector_credential_pair ccp
|
||||
WHERE
|
||||
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
|
||||
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# For good measure
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM index_attempt
|
||||
WHERE connector_credential_pair_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
"""migration confluence to be explicit
|
||||
|
||||
Revision ID: a3795dce87be
|
||||
Revises: 1f60f60c3401
|
||||
Create Date: 2024-09-01 13:52:12.006740
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.sql import table, column
|
||||
|
||||
revision = "a3795dce87be"
|
||||
down_revision = "1f60f60c3401"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url: str,
|
||||
) -> tuple[str, str, str]:
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
def reconstruct_confluence_url(
|
||||
wiki_base: str, space: str, page_id: str, is_cloud: bool
|
||||
) -> str:
|
||||
if is_cloud:
|
||||
url = f"{wiki_base}/spaces/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
else:
|
||||
url = f"{wiki_base}/display/{space}"
|
||||
if page_id:
|
||||
url += f"/pages/{page_id}"
|
||||
return url
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
# Fetch all Confluence connectors
|
||||
connection = op.get_bind()
|
||||
confluence_connectors = connection.execute(
|
||||
sa.select(connector).where(
|
||||
sa.and_(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
wiki_page_url = config["wiki_page_url"]
|
||||
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
|
||||
new_config = {
|
||||
"wiki_base": wiki_base,
|
||||
"space": space,
|
||||
"page_id": page_id,
|
||||
"is_cloud": is_cloud,
|
||||
}
|
||||
|
||||
for key, value in config.items():
|
||||
if key not in ["wiki_page_url"]:
|
||||
new_config[key] = value
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
connector = table(
|
||||
"connector",
|
||||
column("id", sa.Integer),
|
||||
column("source", sa.String()),
|
||||
column("input_type", sa.String()),
|
||||
column("connector_specific_config", postgresql.JSONB),
|
||||
)
|
||||
|
||||
confluence_connectors = (
|
||||
op.get_bind()
|
||||
.execute(
|
||||
sa.select(connector).where(
|
||||
connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
|
||||
)
|
||||
)
|
||||
.fetchall()
|
||||
)
|
||||
|
||||
for row in confluence_connectors:
|
||||
config = row.connector_specific_config
|
||||
if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
|
||||
wiki_page_url = reconstruct_confluence_url(
|
||||
config["wiki_base"],
|
||||
config["space"],
|
||||
config.get("page_id", ""),
|
||||
config["is_cloud"],
|
||||
)
|
||||
|
||||
new_config = {"wiki_page_url": wiki_page_url}
|
||||
new_config.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if k not in ["wiki_base", "space", "page_id", "is_cloud"]
|
||||
}
|
||||
)
|
||||
|
||||
op.execute(
|
||||
connector.update()
|
||||
.where(connector.c.id == row.id)
|
||||
.values(connector_specific_config=new_config)
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Add base_url to CloudEmbeddingProvider
|
||||
|
||||
Revision ID: bceb1e139447
|
||||
Revises: a3795dce87be
|
||||
Create Date: 2024-08-28 17:00:52.554580
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb1e139447"
|
||||
down_revision = "a3795dce87be"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "api_url")
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Add index_attempt_errors table
|
||||
|
||||
Revision ID: c5b692fa265c
|
||||
Revises: 4a951134c801
|
||||
Create Date: 2024-08-08 14:06:39.581972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c5b692fa265c"
|
||||
down_revision = "4a951134c801"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
|
||||
sa.Column("batch", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"doc_summaries",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("traceback", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"index_attempt_id",
|
||||
"index_attempt_errors",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
|
||||
op.drop_table("index_attempt_errors")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Remove _alt suffix from model_name
|
||||
|
||||
Revision ID: d9ec13955951
|
||||
Revises: da4c21c69164
|
||||
Create Date: 2024-08-20 16:31:32.955686
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d9ec13955951"
|
||||
down_revision = "da4c21c69164"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
|
||||
WHERE model_name LIKE '%__danswer_alt_index'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
|
||||
pass
|
||||
@@ -0,0 +1,65 @@
|
||||
"""chosen_assistants changed to jsonb
|
||||
|
||||
Revision ID: da4c21c69164
|
||||
Revises: c5b692fa265c
|
||||
Create Date: 2024-08-18 19:06:47.291491
|
||||
|
||||
"""
|
||||
import json
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "da4c21c69164"
|
||||
down_revision = "c5b692fa265c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
existing_ids_and_chosen_assistants = conn.execute(
|
||||
sa.text("select id, chosen_assistants from public.user")
|
||||
)
|
||||
op.drop_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
|
||||
)
|
||||
for id, chosen_assistants in existing_ids_and_chosen_assistants:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
|
||||
),
|
||||
{"chosen_assistants": chosen_assistants, "id": id},
|
||||
)
|
||||
@@ -9,7 +9,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.embedding_model import (
|
||||
from danswer.db.search_settings import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
@@ -71,14 +71,14 @@ def upgrade() -> None:
|
||||
"query_prefix": old_embedding_model.query_prefix,
|
||||
"passage_prefix": old_embedding_model.passage_prefix,
|
||||
"index_name": old_embedding_model.index_name,
|
||||
"status": old_embedding_model.status,
|
||||
"status": IndexModelStatus.PRESENT,
|
||||
}
|
||||
],
|
||||
)
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model(is_present=False)
|
||||
new_embedding_model = get_new_default_embedding_model()
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Added alternate model to chat message
|
||||
|
||||
Revision ID: ee3f4b47fad5
|
||||
Revises: 4a951134c801
|
||||
Revises: 2d2304e27d8c
|
||||
Create Date: 2024-08-12 00:11:50.915845
|
||||
|
||||
"""
|
||||
@@ -12,17 +12,17 @@ import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ee3f4b47fad5"
|
||||
down_revision = "4a951134c801"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
down_revision = "2d2304e27d8c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("alternate_model", sa.String(length=255), nullable=True),
|
||||
sa.Column("overridden_model", sa.String(length=255), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "alternate_model")
|
||||
op.drop_column("chat_message", "overridden_model")
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 52a219fb5233
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "efb35676026c"
|
||||
down_revision = "0ebb1d516877"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("standard_answer", "match_regex")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,172 @@
|
||||
"""embedding provider by provider type
|
||||
|
||||
Revision ID: f17bf3b0d9f1
|
||||
Revises: 351faebd379d
|
||||
Create Date: 2024-08-21 13:13:31.120460
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f17bf3b0d9f1"
|
||||
down_revision = "351faebd379d"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add provider_type column to embedding_provider
|
||||
op.add_column(
|
||||
"embedding_provider",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type with existing name values
|
||||
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
|
||||
|
||||
# Make provider_type not nullable
|
||||
op.alter_column("embedding_provider", "provider_type", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create a new primary key constraint on provider_type
|
||||
op.create_primary_key(
|
||||
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
|
||||
)
|
||||
|
||||
# Add provider_type column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type for existing embedding models
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET provider_type = (
|
||||
SELECT provider_type
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.id = embedding_model.cloud_provider_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the old id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "id")
|
||||
|
||||
# Drop the name column from embedding_provider
|
||||
op.drop_column("embedding_provider", "name")
|
||||
|
||||
# Drop the default_model_id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "default_model_id")
|
||||
|
||||
# Drop the old cloud_provider_id column from embedding_model
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Create the new foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["provider_type"],
|
||||
["provider_type"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Add back the cloud_provider_id column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
|
||||
|
||||
# Assign incrementing IDs to embedding providers
|
||||
op.execute(
|
||||
"""
|
||||
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
|
||||
"""
|
||||
)
|
||||
|
||||
# Update cloud_provider_id based on provider_type
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET cloud_provider_id = CASE
|
||||
WHEN provider_type IS NULL THEN NULL
|
||||
ELSE (
|
||||
SELECT id
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.provider_type = embedding_model.provider_type
|
||||
)
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_model
|
||||
op.drop_column("embedding_model", "provider_type")
|
||||
|
||||
# Add back the columns to embedding_provider
|
||||
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint on provider_type
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create the original primary key constraint on id
|
||||
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
|
||||
|
||||
# Update name with existing provider_type values
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider
|
||||
SET name = CASE
|
||||
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
|
||||
WHEN provider_type = 'COHERE' THEN 'Cohere'
|
||||
WHEN provider_type = 'GOOGLE' THEN 'Google'
|
||||
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
|
||||
ELSE provider_type
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_provider
|
||||
op.drop_column("embedding_provider", "provider_type")
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7e58d357687"
|
||||
down_revision = "ba98eba0f66a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "has_web_login")
|
||||
@@ -3,21 +3,49 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
info = get_access_info_for_document(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
document_access_info = get_access_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from danswer.server.manage.models import UserPreferences
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
|
||||
@@ -5,8 +5,20 @@ from fastapi_users import schemas
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
User roles
|
||||
- Basic can't perform any admin actions
|
||||
- Admin can perform all admin actions
|
||||
- Curator can perform admin actions for
|
||||
groups they are curators of
|
||||
- Global Curator can perform admin actions
|
||||
for all groups they are a member of
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
GLOBAL_CURATOR = "global_curator"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
@@ -21,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
has_web_login: bool | None = True
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
has_web_login: bool | None = True
|
||||
|
||||
@@ -8,13 +8,17 @@ from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users import models
|
||||
from fastapi_users import schemas
|
||||
@@ -31,6 +35,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
@@ -40,6 +45,7 @@ from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
from danswer.configs.app_configs import SMTP_USER
|
||||
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
@@ -59,10 +65,7 @@ from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation,
|
||||
)
|
||||
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -81,7 +84,7 @@ def verify_auth_setting() -> None:
|
||||
"User must choose a valid user authentication method: "
|
||||
"disabled, basic, or google_oauth"
|
||||
)
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
@@ -106,8 +109,28 @@ def user_needs_to_be_verified() -> bool:
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if (whitelist and email not in whitelist) or not email:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
if not email:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
email_info = validate_email(email) # can raise EmailNotValidError
|
||||
|
||||
for email_whitelist in whitelist:
|
||||
try:
|
||||
# normalized emails are now being inserted into the db
|
||||
# we can remove this normalization on read after some time has passed
|
||||
email_info_whitelist = validate_email(email_whitelist)
|
||||
except EmailNotValidError:
|
||||
continue
|
||||
|
||||
# oddly, normalization does not include lowercasing the user part of the
|
||||
# email address ... which we want to allow
|
||||
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
|
||||
return
|
||||
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
@@ -164,7 +187,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create: schemas.UC | UserCreate,
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> models.UP:
|
||||
) -> User:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
@@ -173,7 +196,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
not user.has_web_login
|
||||
and hasattr(user_create, "has_web_login")
|
||||
and user_create.has_web_login
|
||||
):
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
has_web_login=True,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
self: "BaseUserManager[models.UOAP, models.ID]",
|
||||
@@ -203,18 +246,35 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# NOTE: google oauth expires after 1hr. We don't want to force the user to
|
||||
# re-authenticate that frequently, so for now we'll just ignore this for
|
||||
# google oauth users
|
||||
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": None})
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.has_web_login:
|
||||
await self.user_db.update(
|
||||
user,
|
||||
update_dict={
|
||||
"is_verified": is_verified_by_default,
|
||||
"has_web_login": True,
|
||||
},
|
||||
)
|
||||
user.is_verified = is_verified_by_default
|
||||
user.has_web_login = True
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.info(f"User {user.id} has registered.")
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
@@ -224,19 +284,35 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
verify_email_domain(user.email)
|
||||
|
||||
logger.info(
|
||||
logger.notice(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
user = await super().authenticate(credentials)
|
||||
if user is None:
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
pass
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -339,6 +415,7 @@ async def optional_user(
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return None
|
||||
@@ -355,7 +432,11 @@ async def double_check_user(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
@@ -364,12 +445,40 @@ async def double_check_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
@@ -377,7 +486,12 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not an admin.",
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
299
backend/danswer/background/celery/celery_redis.py
Normal file
299
backend/danswer/background/celery/celery_redis.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id, current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
@@ -22,7 +21,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
@@ -33,7 +31,7 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
@@ -45,7 +43,7 @@ def get_deletion_status(
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = get_deletion_status(connector_id, credential_id, db_session)
|
||||
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
@@ -65,7 +63,7 @@ def should_kick_off_deletion_of_cc_pair(
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = get_deletion_status(
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
@@ -81,21 +79,6 @@ def should_kick_off_deletion_of_cc_pair(
|
||||
return True
|
||||
|
||||
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
|
||||
44
backend/danswer/background/celery/celeryconfig.py
Normal file
44
backend/danswer/background/celery/celeryconfig.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.app_configs import REDIS_SSL
|
||||
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
REDIS_SCHEME = "redis"
|
||||
|
||||
# SSL-specific query parameters for Redis URL
|
||||
SSL_QUERY_PARAMS = ""
|
||||
if REDIS_SSL:
|
||||
REDIS_SCHEME = "rediss"
|
||||
SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}"
|
||||
if REDIS_SSL_CA_CERTS:
|
||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
}
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
@@ -151,8 +151,7 @@ def delete_connector_credential_pair(
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
@@ -185,11 +184,11 @@ def delete_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.debug("Found no credentials left for connector, deleting connector")
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
logger.notice(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
|
||||
@@ -11,12 +11,9 @@ from danswer.background.indexing.tracer import DanswerTracer
|
||||
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.connector_runner import ConnectorRunner
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
@@ -24,6 +21,7 @@ from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -41,12 +39,12 @@ logger = setup_logger()
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
|
||||
def _get_document_generator(
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
@@ -76,31 +74,9 @@ def _get_document_generator(
|
||||
)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if (
|
||||
attempt.connector_credential_pair.connector_id is None
|
||||
or attempt.connector_credential_pair.connector_id is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
|
||||
logger.info(f"Polling for updates between {start_time} and {end_time}")
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator
|
||||
return ConnectorRunner(
|
||||
connector=runnable_connector, time_range=(start_time, end_time)
|
||||
)
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
@@ -114,55 +90,62 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
search_settings = index_attempt.search_settings
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
|
||||
is_primary = search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
|
||||
db_embedding_model
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (db_embedding_model.status == IndexModelStatus.FUTURE),
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
|
||||
)
|
||||
|
||||
last_successful_index_time = (
|
||||
db_connector.indexing_start.timestamp()
|
||||
if index_attempt.from_beginning and db_connector.indexing_start is not None
|
||||
else (
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
earliest_index_time
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = DanswerTracer()
|
||||
tracer.start()
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
@@ -181,7 +164,7 @@ def _run_indexing(
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
doc_batch_generator = _get_document_generator(
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
@@ -193,15 +176,19 @@ def _run_indexing(
|
||||
tracer_counter = 0
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
for doc_batch in doc_batch_generator:
|
||||
for doc_batch in connector_runner.run():
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and db_embedding_model.status != IndexModelStatus.FUTURE
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and search_settings.status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
@@ -228,13 +215,13 @@ def _run_indexing(
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
@@ -261,7 +248,7 @@ def _run_indexing(
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
|
||||
):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
|
||||
)
|
||||
tracer.snap()
|
||||
@@ -277,7 +264,7 @@ def _run_indexing(
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
logger.exception(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
@@ -289,7 +276,7 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
@@ -315,15 +302,52 @@ def _run_indexing(
|
||||
break
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
tracer.stop()
|
||||
logger.info("Memory tracer stopped.")
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
if (
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -332,11 +356,6 @@ def _run_indexing(
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Connector succeeded: docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
|
||||
# make sure that the index attempt can't change in between checking the
|
||||
@@ -365,17 +384,22 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
|
||||
) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
|
||||
@@ -48,9 +48,9 @@ class DanswerTracer:
|
||||
|
||||
stats = self.snapshot.statistics("traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.info(f"Tracer snap: {s}")
|
||||
logger.debug(f"Tracer snap: {s}")
|
||||
for line in s.traceback:
|
||||
logger.info(f"* {line}")
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
@staticmethod
|
||||
def log_diff(
|
||||
@@ -60,9 +60,9 @@ class DanswerTracer:
|
||||
) -> None:
|
||||
stats = snap_current.compare_to(snap_previous, "traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.info(f"Tracer diff: {s}")
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.info(f"* {line}")
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def log_previous_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_prev:
|
||||
|
||||
@@ -93,9 +93,16 @@ def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA
|
||||
kwargs_for_build_name = kwargs or {}
|
||||
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# mark the task as started
|
||||
# register_task must come before fn = apply_async or else the task
|
||||
# might run mark_task_start (and crash) before the task row exists
|
||||
db_task = register_task(task_name, db_session)
|
||||
|
||||
task = fn(args, kwargs, *other_args, **other_kwargs)
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
# we update the celery task id for diagnostic purposes
|
||||
# but it isn't currently used by any code
|
||||
db_task.task_id = task.id
|
||||
db_session.commit()
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@@ -17,15 +17,13 @@ from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
@@ -33,11 +31,14 @@ from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -60,20 +61,27 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
def _should_create_new_indexing(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
model: EmbeddingModel,
|
||||
search_settings_instance: SearchSettings,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
|
||||
if (
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if model.status == IndexModelStatus.FUTURE:
|
||||
if search_settings_instance.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
@@ -95,7 +103,7 @@ def _should_create_new_indexing(
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.PAUSED or connector.id == 0:
|
||||
if not cc_pair.status.is_active() or connector.id == 0:
|
||||
return False
|
||||
|
||||
if not last_index:
|
||||
@@ -120,16 +128,6 @@ def _should_create_new_indexing(
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
|
||||
if index_attempt is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
index_attempt.status == IndexingStatus.FAILED
|
||||
or index_attempt.status == IndexingStatus.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
@@ -170,35 +168,42 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.embedding_model_id,
|
||||
attempt.search_settings_id,
|
||||
)
|
||||
)
|
||||
|
||||
embedding_models = [get_current_db_embedding_model(db_session)]
|
||||
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for model in embedding_models:
|
||||
for search_settings_instance in search_settings:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, model.id) in ongoing:
|
||||
if (cc_pair.id, search_settings_instance.id) in ongoing:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, model.id, db_session
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(cc_pair.id, model.id, db_session)
|
||||
create_index_attempt(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
@@ -215,10 +220,12 @@ def cleanup_indexing_jobs(
|
||||
)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done() and not _is_indexing_job_marked_as_finished(
|
||||
index_attempt
|
||||
):
|
||||
continue
|
||||
if not job.done():
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
@@ -293,7 +300,7 @@ def kickoff_indexing_jobs(
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.embedding_model)
|
||||
(attempt, attempt.search_settings)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
@@ -305,10 +312,10 @@ def kickoff_indexing_jobs(
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
for attempt, embedding_model in new_indexing_attempts:
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
embedding_model.status == IndexModelStatus.FUTURE
|
||||
if embedding_model is not None
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
@@ -334,6 +341,7 @@ def kickoff_indexing_jobs(
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
@@ -341,6 +349,7 @@ def kickoff_indexing_jobs(
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
@@ -381,17 +390,21 @@ def update_loop(
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.debug("Running a first inference to warm up embedding model")
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=db_embedding_model,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
@@ -454,7 +467,7 @@ def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.info("Starting indexing service")
|
||||
logger.notice("Starting indexing service")
|
||||
update_loop()
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,8 @@ def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
parent_id: int | None = None,
|
||||
# Optional id at which we finish processing
|
||||
stop_at_message_id: int | None = None,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
@@ -62,7 +63,12 @@ def create_chat_chain(
|
||||
current_message: ChatMessage | None = root_message
|
||||
while current_message is not None:
|
||||
child_msg = current_message.latest_child_message
|
||||
if not child_msg or (parent_id and current_message.id == parent_id):
|
||||
|
||||
# Break if at the end of the chain
|
||||
# or have reached the `final_id` of the submitted message
|
||||
if not child_msg or (
|
||||
stop_at_message_id and current_message.id == stop_at_message_id
|
||||
):
|
||||
break
|
||||
current_message = id_to_msg.get(child_msg)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -9,6 +10,8 @@ from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -34,16 +37,37 @@ class QADocsResponse(RetrievalDocs):
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class FinalUsedContextDocsResponse(BaseModel):
|
||||
final_context_docs: list[LlmDoc]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
@@ -64,10 +88,6 @@ class DocumentRelevance(BaseModel):
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class Delimiter(BaseModel):
|
||||
delimiter: bool
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
@@ -80,6 +100,16 @@ class CitationInfo(BaseModel):
|
||||
document_id: str
|
||||
|
||||
|
||||
class AllCitations(BaseModel):
|
||||
citations: list[CitationInfo]
|
||||
|
||||
|
||||
# This is a mapping of the citation number to the document index within
|
||||
# the result search doc set
|
||||
class MessageSpecificCitations(BaseModel):
|
||||
citation_map: dict[int, int]
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
@@ -125,7 +155,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
@@ -134,7 +164,7 @@ class ImageGenerationDisplay(BaseModel):
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: dict
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
@@ -146,7 +176,8 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| Delimiter
|
||||
| GraphGenerationDisplay
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import partial
|
||||
@@ -6,15 +7,19 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import Delimiter
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -32,13 +37,13 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
@@ -69,11 +74,16 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingResponse
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -85,13 +95,15 @@ from danswer.tools.internet_search.internet_search_tool import (
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallMetadata
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -100,9 +112,9 @@ from danswer.utils.timing import log_generator_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def translate_citations(
|
||||
def _translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
) -> MessageSpecificCitations:
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
@@ -117,7 +129,7 @@ def translate_citations(
|
||||
citation.citation_num
|
||||
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
||||
|
||||
return citation_to_saved_doc_id_map
|
||||
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
|
||||
|
||||
|
||||
def _handle_search_tool_response_summary(
|
||||
@@ -238,14 +250,18 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| GraphingResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| GraphGenerationDisplay
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| Delimiter
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -271,6 +287,11 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
new_msg_req.chunks_below = 0
|
||||
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@@ -327,9 +348,9 @@ def stream_chat_message_objects(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
@@ -350,7 +371,7 @@ def stream_chat_message_objects(
|
||||
|
||||
if new_msg_req.regenerate:
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
parent_id=parent_id,
|
||||
stop_at_message_id=parent_id,
|
||||
chat_session_id=chat_session_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -461,8 +482,6 @@ def stream_chat_message_objects(
|
||||
else default_num_chunks
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
use_sections=new_msg_req.chunks_above > 0
|
||||
or new_msg_req.chunks_below > 0,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
@@ -477,16 +496,17 @@ def stream_chat_message_objects(
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
alternate_model = (
|
||||
overridden_model = (
|
||||
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
||||
)
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
prompt_id=prompt_id,
|
||||
alternate_model=alternate_model,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
@@ -517,8 +537,21 @@ def stream_chat_message_objects(
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if (
|
||||
tool_cls.__name__ == CSVAnalysisTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
|
||||
|
||||
if (
|
||||
tool_cls.__name__ == GraphingTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
|
||||
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
@@ -589,19 +622,27 @@ def stream_chat_message_objects(
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -610,7 +651,6 @@ def stream_chat_message_objects(
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
tool_has_been_called = False # TODO remove
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
@@ -648,212 +688,245 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
yielded_message_id_info = True
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
tool_has_been_called = True
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
db_citations = None
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments={
|
||||
k: v if not isinstance(v, bytes) else v.decode("utf-8")
|
||||
for k, v in tool_result.tool_args.items()
|
||||
},
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yielded_message_id_info = False
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
|
||||
else:
|
||||
if isinstance(packet, Delimiter):
|
||||
db_citations = None
|
||||
if not yielded_message_id_info:
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
yielded_message_id_info = True
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == GRAPHING_RESPONSE_ID:
|
||||
graph_generation = cast(GraphingResponse, packet.response)
|
||||
yield graph_generation
|
||||
|
||||
# yield GraphGenerationDisplay(
|
||||
# file_id=graph_generation.extra_graph_display.file_id,
|
||||
# line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
# )
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query
|
||||
if qa_docs_response
|
||||
else None
|
||||
),
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
yield Delimiter(delimiter=True)
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
# message=,
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
# error=,
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallMetadata):
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
yield StreamingError(error=client_error_msg, stack_trace=error_msg)
|
||||
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
if not tool_has_been_called:
|
||||
try:
|
||||
db_citations = None
|
||||
if reference_db_search_docs:
|
||||
db_citations = translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
),
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else None,
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
if reference_db_search_docs:
|
||||
message_specific_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
db_session.commit() # actually save user / assistant message
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
),
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else None,
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(error_msg)
|
||||
logger.debug("Committing messages")
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
yield StreamingError(error="Failed to parse LLM output")
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(error_msg)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
yield StreamingError(error="Failed to parse LLM output")
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
@@ -874,4 +947,4 @@ def stream_chat_message(
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.dict())
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
@@ -42,8 +42,7 @@ prompts:
|
||||
task: >
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TypedDict
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -93,6 +93,14 @@ SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Danswer will listen to the `expires_at` returned by the identity
|
||||
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
|
||||
# after this time has elapsed. Disabled since by default many auth providers
|
||||
# have very short expiry times (e.g. 1 hour) which provide a poor user experience
|
||||
TRACK_EXTERNAL_IDP_EXPIRY = (
|
||||
os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# DB Configs
|
||||
@@ -118,6 +126,7 @@ try:
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
@@ -126,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
|
||||
@@ -141,6 +150,20 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
|
||||
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -192,8 +215,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
]
|
||||
|
||||
# Avoid to get archived pages
|
||||
CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
|
||||
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save pages labels as Danswer metadata tags
|
||||
@@ -204,7 +227,12 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 50 * 1024 * 1024)
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
# Attachments with more chars than this will not be indexed. This is to prevent extremely
|
||||
# large files from freezing indexing. 200,000 is ~100 google doc pages.
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
@@ -295,6 +323,10 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
@@ -321,6 +353,10 @@ LOG_VESPA_TIMING_INFORMATION = (
|
||||
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
|
||||
)
|
||||
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
|
||||
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
|
||||
LOG_POSTGRES_CONN_COUNTS = (
|
||||
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
|
||||
)
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -31,8 +31,9 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
|
||||
# Currently only applies to search flow not chat
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
# Whether the LLM should be used to decide if a search would help given the chat history
|
||||
DISABLE_LLM_CHOOSE_SEARCH = (
|
||||
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
|
||||
@@ -44,7 +45,7 @@ DISABLE_LLM_QUERY_REPHRASE = (
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
HYBRID_ALPHA_KEYWORD = max(
|
||||
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
|
||||
)
|
||||
@@ -53,7 +54,7 @@ HYBRID_ALPHA_KEYWORD = max(
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
# if the title is separated out. Title is most of a "boost" than a separate field.
|
||||
TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
@@ -82,8 +83,15 @@ DISABLE_LLM_DOC_RELEVANCE = (
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
# Set this to "true" to hard delete chats
|
||||
# This will make chats unviewable by admins after a user deletes them
|
||||
# As opposed to soft deleting them, which just hides them from non-admin users
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
@@ -12,10 +13,6 @@ ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
|
||||
|
||||
# For tool calling
|
||||
MAXIMUM_TOOL_CALL_SEQUENCE = 5
|
||||
|
||||
# For chunking/processing chunks
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
@@ -60,9 +57,12 @@ KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
|
||||
KV_SETTINGS_KEY = "danswer_settings"
|
||||
KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
@@ -165,3 +165,28 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
GRAPH_GEN = "graph_gen"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
|
||||
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
||||
@@ -73,3 +73,15 @@ DANSWER_BOT_FEEDBACK_REMINDER = int(
|
||||
DANSWER_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses DanswerBot can send in a given time period.
|
||||
# Set to 0 to disable the limit.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
|
||||
)
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
|
||||
# of seconds until the response limit is reset.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
|
||||
)
|
||||
|
||||
@@ -39,9 +39,13 @@ SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
|
||||
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
|
||||
# User's set embedding batch size overrides the default encoding batch sizes
|
||||
EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None
|
||||
|
||||
BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8
|
||||
# don't send over too many chunks at once, as sending too many could cause timeouts
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512
|
||||
# For score display purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 1
|
||||
CROSS_ENCODER_RANGE_MIN = 0
|
||||
@@ -51,37 +55,23 @@ CROSS_ENCODER_RANGE_MIN = 0
|
||||
# Generative AI Model Configs
|
||||
#####
|
||||
|
||||
# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default,
|
||||
# be sure to use one that is LiteLLM compatible:
|
||||
# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables
|
||||
# The provider is the prefix before / in the model argument
|
||||
|
||||
# Additionally Danswer supports GPT4All and custom request library based models
|
||||
# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach
|
||||
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
|
||||
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
|
||||
# If using Azure, it's the engine name, for example: Danswer
|
||||
# NOTE: the 3 below should only be used for dev.
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
|
||||
|
||||
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
|
||||
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
|
||||
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
|
||||
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = (
|
||||
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None
|
||||
)
|
||||
|
||||
# API Base, such as (for Azure): https://danswer.openai.azure.com/
|
||||
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
|
||||
# API Version, such as (for Azure): 2023-09-15-preview
|
||||
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
# LiteLLM custom_llm_provider
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
|
||||
# This is the minimum token context we will leave for the LLM to generate an answer
|
||||
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
# as this drives up the cost unnecessarily
|
||||
|
||||
@@ -59,6 +59,8 @@ if __name__ == "__main__":
|
||||
latest_docs = test_connector.poll_source(one_day_ago, current)
|
||||
```
|
||||
|
||||
> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main.
|
||||
|
||||
|
||||
### Additional Required Changes:
|
||||
#### Backend Changes
|
||||
@@ -68,17 +70,16 @@ if __name__ == "__main__":
|
||||
[here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33)
|
||||
|
||||
#### Frontend Changes
|
||||
- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/`
|
||||
- Create the new icon, type, source, and filter changes
|
||||
(refer to existing [PR](https://github.com/danswer-ai/danswer/pull/139))
|
||||
- Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/sources.ts#L59).
|
||||
- Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/connectors/connectors.ts#L79).
|
||||
|
||||
#### Docs Changes
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs
|
||||
|
||||
connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs.
|
||||
|
||||
### Before opening PR
|
||||
1. Be sure to fully test changes end to end with setting up the connector and updating the index with new docs from the
|
||||
new connector.
|
||||
2. Be sure to run the linting/formatting, refer to the formatting and linting section in
|
||||
new connector. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page).
|
||||
2. Add a folder + tests under `backend/tests/daily/connectors` director. For an example, checkout the [test for Confluence](https://github.com/danswer-ai/danswer/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to setup the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass.
|
||||
3. Be sure to run the linting/formatting, refer to the formatting and linting section in
|
||||
[CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md#formatting-and-linting)
|
||||
|
||||
@@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
Raises ValueError for unsupported bucket types.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
|
||||
)
|
||||
|
||||
@@ -220,7 +220,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.info("Loading blob objects")
|
||||
logger.debug("Loading blob objects")
|
||||
return self._yield_blob_objects(
|
||||
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
end=datetime.now(timezone.utc),
|
||||
|
||||
@@ -7,14 +7,16 @@ from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
@@ -42,77 +44,12 @@ logger = setup_logger()
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
|
||||
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split("/spaces")[0]
|
||||
)
|
||||
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base print()
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
parsed_url.scheme
|
||||
+ "://"
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
)
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
|
||||
wiki_url
|
||||
)
|
||||
else:
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR = (
|
||||
"User not permitted to view attachments on content"
|
||||
)
|
||||
NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
|
||||
"No parent or not permitted to view content with id"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -200,19 +137,38 @@ def _comment_dfs(
|
||||
comments_str += "\nComment:\n" + parse_html_page(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
comments_str = _comment_dfs(
|
||||
comments_str, child_comment_pages, confluence_client
|
||||
)
|
||||
try:
|
||||
child_comment_pages = get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
start=None,
|
||||
limit=None,
|
||||
expand="body.storage.value",
|
||||
)
|
||||
comments_str = _comment_dfs(
|
||||
comments_str, child_comment_pages, confluence_client
|
||||
)
|
||||
except HTTPError as e:
|
||||
# not the cleanest, but I'm not aware of a nicer way to check the error
|
||||
if NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR not in str(e):
|
||||
raise
|
||||
|
||||
return comments_str
|
||||
|
||||
|
||||
def _datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
if datetime_object.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
class RecursiveIndexer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -342,7 +298,10 @@ class RecursiveIndexer:
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_page_url: str,
|
||||
wiki_base: str,
|
||||
space: str,
|
||||
is_cloud: bool,
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -356,15 +315,15 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_recursively = index_recursively
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
self.page_id,
|
||||
self.is_cloud,
|
||||
) = extract_confluence_keys_from_url(wiki_page_url)
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
self.space = space
|
||||
self.page_id = page_id
|
||||
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
self.space_level_scan = False
|
||||
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
@@ -384,7 +343,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
username=username if self.is_cloud else None,
|
||||
password=access_token if self.is_cloud else None,
|
||||
token=access_token if not self.is_cloud else None,
|
||||
cloud=self.is_cloud,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -403,9 +361,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
@@ -426,9 +382,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
None
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
|
||||
else "current"
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
@@ -535,145 +491,249 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
logger.exception("Ran into exception when fetching labels from Confluence")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_download_link(
|
||||
cls, confluence_client: Confluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
@classmethod
|
||||
def _attachment_to_content(
|
||||
cls,
|
||||
confluence_client: Confluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
return None
|
||||
|
||||
download_link = cls._attachment_to_download_link(confluence_client, attachment)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
def _fetch_attachments(
|
||||
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
|
||||
) -> str:
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
unused_attachments: list = []
|
||||
|
||||
get_attachments_from_content = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_attachments_from_content
|
||||
)
|
||||
files_attachment_content: list = []
|
||||
|
||||
try:
|
||||
expand = "history.lastUpdated,metadata.labels"
|
||||
attachments_container = get_attachments_from_content(
|
||||
page_id, start=0, limit=500
|
||||
page_id, start=0, limit=500, expand=expand
|
||||
)
|
||||
for attachment in attachments_container["results"]:
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
continue
|
||||
|
||||
if attachment["title"] not in files_in_used:
|
||||
unused_attachments.append(attachment)
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
response = confluence_client._session.get(download_link)
|
||||
|
||||
if response.status_code == 200:
|
||||
extract = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
files_attachment_content.append(extract)
|
||||
attachment_content = self._attachment_to_content(
|
||||
confluence_client, attachment
|
||||
)
|
||||
if attachment_content:
|
||||
files_attachment_content.append(attachment_content)
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(
|
||||
e, HTTPError
|
||||
) and NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR in str(e):
|
||||
logger.warning(
|
||||
f"User does not have access to attachments on page '{page_id}'"
|
||||
)
|
||||
return "", []
|
||||
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
logger.exception(
|
||||
f"Ran into exception when fetching attachments from Confluence: {e}"
|
||||
)
|
||||
|
||||
return "\n".join(files_attachment_content)
|
||||
return "\n".join(files_attachment_content), unused_attachments
|
||||
|
||||
def _get_doc_batch(
|
||||
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> tuple[list[Document], int]:
|
||||
) -> tuple[list[Document], list[dict[str, Any]], int]:
|
||||
doc_batch: list[Document] = []
|
||||
unused_attachments: list[dict[str, Any]] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
|
||||
for page in batch:
|
||||
last_modified_str = page["version"]["when"]
|
||||
last_modified = _datetime_from_string(page["version"]["when"])
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
last_modified = datetime.fromisoformat(last_modified_str)
|
||||
|
||||
if last_modified.tzinfo is None:
|
||||
# If no timezone info, assume it is UTC
|
||||
last_modified = last_modified.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# If not in UTC, translate it
|
||||
last_modified = last_modified.astimezone(timezone.utc)
|
||||
if time_filter and not time_filter(last_modified):
|
||||
continue
|
||||
|
||||
if time_filter is None or time_filter(last_modified):
|
||||
page_id = page["id"]
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"]
|
||||
.get("storage", page["body"].get("view", {}))
|
||||
.get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
attachment_text = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": self.space
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
page["body"].get("storage", page["body"].get("view", {})).get("value")
|
||||
)
|
||||
page_url = self.wiki_base + page["_links"]["webui"]
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
|
||||
files_in_used = get_used_attachments(page_html, self.confluence_client)
|
||||
attachment_text, unused_page_attachments = self._fetch_attachments(
|
||||
self.confluence_client, page_id, files_in_used
|
||||
)
|
||||
unused_attachments.extend(unused_page_attachments)
|
||||
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=page_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
)
|
||||
return (
|
||||
doc_batch,
|
||||
unused_attachments,
|
||||
len(batch),
|
||||
)
|
||||
|
||||
def _get_attachment_batch(
|
||||
self,
|
||||
start_ind: int,
|
||||
attachments: list[dict[str, Any]],
|
||||
time_filter: Callable[[datetime], bool] | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
end_ind = min(start_ind + self.batch_size, len(attachments))
|
||||
|
||||
for attachment in attachments[start_ind:end_ind]:
|
||||
last_updated = _datetime_from_string(
|
||||
attachment["history"]["lastUpdated"]["when"]
|
||||
)
|
||||
|
||||
if time_filter and not time_filter(last_updated):
|
||||
continue
|
||||
|
||||
attachment_url = self._attachment_to_download_link(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
attachment_content = self._attachment_to_content(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
if attachment_content is None:
|
||||
continue
|
||||
|
||||
creator_email = attachment["history"]["createdBy"].get("email")
|
||||
|
||||
comment = attachment["metadata"].get("comment", "")
|
||||
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
|
||||
|
||||
attachment_labels: list[str] = []
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
for label in attachment["metadata"]["labels"]["results"]:
|
||||
attachment_labels.append(label["name"])
|
||||
|
||||
doc_metadata["labels"] = attachment_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=attachment_url,
|
||||
sections=[Section(link=attachment_url, text=attachment_content)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment["title"],
|
||||
doc_updated_at=last_updated,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=creator_email)]
|
||||
if creator_email
|
||||
else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return doc_batch, end_ind - start_ind
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, num_pages = self._get_doc_batch(start_ind)
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
start_ind
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
start_ind += num_pages
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -681,9 +741,23 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if num_pages < self.batch_size:
|
||||
break
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind, unused_attachments
|
||||
)
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
unused_attachments = []
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
@@ -692,9 +766,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
doc_batch, num_pages = self._get_doc_batch(
|
||||
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
|
||||
start_ind, time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
unused_attachments.extend(unused_attachments_batch)
|
||||
|
||||
start_ind += num_pages
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -702,9 +778,29 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if num_pages < self.batch_size:
|
||||
break
|
||||
|
||||
start_ind = 0
|
||||
while True:
|
||||
attachment_batch, num_attachments = self._get_attachment_batch(
|
||||
start_ind,
|
||||
unused_attachments,
|
||||
time_filter=lambda t: start_time <= t <= end_time,
|
||||
)
|
||||
start_ind += num_attachments
|
||||
if attachment_batch:
|
||||
yield attachment_batch
|
||||
|
||||
if num_attachments < self.batch_size:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
index_recursively=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
|
||||
@@ -23,25 +23,33 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
max_retries = 5
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
|
||||
for attempt in range(10):
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
try:
|
||||
retry_after = int(e.response.headers.get("Retry-After"))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after:
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
@@ -55,5 +63,14 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
logger.warning(f"Confluence Internal Error, retrying... {e}")
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
raise e
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
70
backend/danswer/connectors/connector_runner.py
Normal file
70
backend/danswer/connectors/connector_runner.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
|
||||
class ConnectorRunner:
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
time_range: TimeRange | None = None,
|
||||
fail_loudly: bool = False,
|
||||
):
|
||||
self.connector = connector
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
self.doc_batch_generator = self.connector.poll_source(
|
||||
time_range[0].timestamp(), time_range[1].timestamp()
|
||||
)
|
||||
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
if time_range and fail_loudly:
|
||||
raise ValueError(
|
||||
"time_range specified, but passed in connector is not a PollConnector"
|
||||
)
|
||||
|
||||
self.doc_batch_generator = self.connector.load_from_state()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
yield from self.doc_batch_generator
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
# Traverse the traceback to find the last frame where the exception was raised
|
||||
tb = exc_traceback
|
||||
if tb is None:
|
||||
logger.error("No traceback found for exception")
|
||||
raise
|
||||
|
||||
while tb.tb_next:
|
||||
tb = tb.tb_next # Move to the next frame in the traceback
|
||||
|
||||
# Get the local variables from the frame where the exception occurred
|
||||
local_vars = tb.tb_frame.f_locals
|
||||
local_vars_str = "\n".join(
|
||||
f"{key}: {value}" for key, value in local_vars.items()
|
||||
)
|
||||
logger.error(
|
||||
f"Error in connector. type: {exc_type};\n"
|
||||
f"local_vars below -> \n{local_vars_str}"
|
||||
)
|
||||
raise
|
||||
@@ -56,7 +56,7 @@ class _RateLimitDecorator:
|
||||
sleep_cnt = 0
|
||||
while len(self.call_history) == self.max_calls:
|
||||
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
|
||||
logger.info(
|
||||
logger.notice(
|
||||
f"Rate limit exceeded for function {func.__name__}. "
|
||||
f"Waiting {sleep_time} seconds before retrying."
|
||||
)
|
||||
|
||||
@@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]:
|
||||
return jira_base, jira_project
|
||||
|
||||
|
||||
def extract_text_from_content(content: dict) -> str:
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
texts = []
|
||||
if "content" in content:
|
||||
for block in content["content"]:
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
@@ -72,18 +77,15 @@ def _get_comment_strs(
|
||||
comment_strs = []
|
||||
for comment in jira.fields.comment.comments:
|
||||
try:
|
||||
if hasattr(comment, "body"):
|
||||
body_text = extract_text_from_content(comment.raw["body"])
|
||||
elif hasattr(comment, "raw"):
|
||||
body = comment.raw.get("body", "No body content available")
|
||||
body_text = (
|
||||
extract_text_from_content(body) if isinstance(body, dict) else body
|
||||
)
|
||||
else:
|
||||
body_text = "No body attribute found"
|
||||
body_text = (
|
||||
comment.body
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(comment.raw["body"])
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(comment, "author")
|
||||
and hasattr(comment.author, "emailAddress")
|
||||
and comment.author.emailAddress in comment_email_blacklist
|
||||
):
|
||||
continue # Skip adding comment if author's email is in blacklist
|
||||
@@ -126,11 +128,14 @@ def fetch_jira_issues_batch(
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
jira.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(jira.raw["fields"]["description"])
|
||||
)
|
||||
comments = _get_comment_strs(jira, comment_email_blacklist)
|
||||
semantic_rep = (
|
||||
f"{jira.fields.description}\n"
|
||||
if jira.fields.description
|
||||
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
|
||||
semantic_rep = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
@@ -23,7 +23,7 @@ from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.extract_file_text import get_file_ext
|
||||
from danswer.file_processing.extract_file_text import is_text_file_extension
|
||||
from danswer.file_processing.extract_file_text import load_files_from_zip
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -75,7 +75,7 @@ def _process_file(
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf":
|
||||
file_content_raw = pdf_to_text(file=file, pdf_pass=pdf_pass)
|
||||
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
|
||||
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
|
||||
@@ -38,7 +38,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
tzinfo=timezone.utc
|
||||
) - datetime.now(tz=timezone.utc)
|
||||
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
|
||||
logger.info(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
time.sleep(sleep_time.seconds)
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def get_gmail_creds_for_authorized_user(
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.info("Refreshed Gmail tokens.")
|
||||
logger.notice("Refreshed Gmail tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh gmail access token due to: {e}")
|
||||
@@ -125,7 +125,7 @@ def update_gmail_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_gmail_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.dict(),
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
|
||||
@@ -81,10 +81,10 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
|
||||
for workspace in workspace_list:
|
||||
if workspace:
|
||||
logger.info(f"Updating workspace: {workspace}")
|
||||
logger.info(f"Updating Gong workspace: {workspace}")
|
||||
workspace_id = workspace_map.get(workspace)
|
||||
if not workspace_id:
|
||||
logger.error(f"Invalid workspace: {workspace}")
|
||||
logger.error(f"Invalid Gong workspace: {workspace}")
|
||||
if not self.continue_on_fail:
|
||||
raise ValueError(f"Invalid workspace: {workspace}")
|
||||
continue
|
||||
|
||||
@@ -41,8 +41,8 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import docx_to_text
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import pptx_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -62,6 +62,8 @@ class GDriveMimeType(str, Enum):
|
||||
POWERPOINT = (
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
)
|
||||
PLAIN_TEXT = "text/plain"
|
||||
MARKDOWN = "text/markdown"
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
@@ -267,7 +269,7 @@ def get_all_files_batched(
|
||||
yield from batch_generator(
|
||||
items=found_files,
|
||||
batch_size=batch_size,
|
||||
pre_batch_yield=lambda batch_files: logger.info(
|
||||
pre_batch_yield=lambda batch_files: logger.debug(
|
||||
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
|
||||
),
|
||||
)
|
||||
@@ -316,25 +318,29 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = "text/plain"
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
export_mime_type = "text/csv"
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
export_mime_type = "text/plain"
|
||||
|
||||
response = (
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return response.decode("utf-8")
|
||||
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pdf_to_text(file=io.BytesIO(response))
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
return text
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
@@ -50,7 +50,7 @@ def get_google_drive_creds_for_authorized_user(
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
logger.info("Refreshed Google Drive tokens.")
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to refresh google drive access token due to: {e}")
|
||||
@@ -106,7 +106,7 @@ def update_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.dict(),
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
|
||||
@@ -103,6 +103,10 @@ class GuruConnector(LoadConnector, PollConnector):
|
||||
# In UI it's called Folders
|
||||
metadata_dict["folders"] = boards
|
||||
|
||||
collection = card.get("collection", {})
|
||||
if collection:
|
||||
metadata_dict["collection_name"] = collection.get("name", "")
|
||||
|
||||
owner = card.get("owner", {})
|
||||
author = None
|
||||
if owner:
|
||||
|
||||
@@ -166,6 +166,36 @@ class Document(DocumentBase):
|
||||
)
|
||||
|
||||
|
||||
class DocumentErrorSummary(BaseModel):
|
||||
id: str
|
||||
semantic_id: str
|
||||
section_link: str | None
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
|
||||
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
|
||||
return cls(
|
||||
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
|
||||
return cls(
|
||||
id=str(data.get("id")),
|
||||
semantic_id=str(data.get("semantic_id")),
|
||||
section_link=str(data.get("section_link")),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"semantic_id": self.semantic_id,
|
||||
"section_link": self.section_link,
|
||||
}
|
||||
|
||||
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
batch_num: int | None = None
|
||||
num_exceptions: int = 0
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
@@ -237,6 +237,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
continue
|
||||
|
||||
if result_type == "external_object_instance_page":
|
||||
logger.warning(
|
||||
f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
|
||||
f"Notion API does not currently support reading external blocks (as of 24/07/03) "
|
||||
f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)"
|
||||
)
|
||||
continue
|
||||
|
||||
cur_result_text_arr = []
|
||||
if "rich_text" in result_obj:
|
||||
for rich_text in result_obj["rich_text"]:
|
||||
|
||||
@@ -98,6 +98,15 @@ class ProductboardConnector(PollConnector):
|
||||
owner = self._get_owner_email(feature)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
entity_type = feature.get("type", "feature")
|
||||
if entity_type:
|
||||
metadata["entity_type"] = str(entity_type)
|
||||
|
||||
status = feature.get("status", {}).get("name")
|
||||
if status:
|
||||
metadata["status"] = str(status)
|
||||
|
||||
yield Document(
|
||||
id=feature["id"],
|
||||
sections=[
|
||||
@@ -110,10 +119,7 @@ class ProductboardConnector(PollConnector):
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(feature["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"entity_type": feature["type"],
|
||||
"status": feature["status"]["name"],
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _get_components(self) -> Generator[Document, None, None]:
|
||||
@@ -174,6 +180,12 @@ class ProductboardConnector(PollConnector):
|
||||
owner = self._get_owner_email(objective)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"entity_type": "objective",
|
||||
}
|
||||
if objective.get("state"):
|
||||
metadata["state"] = str(objective["state"])
|
||||
|
||||
yield Document(
|
||||
id=objective["id"],
|
||||
sections=[
|
||||
@@ -186,10 +198,7 @@ class ProductboardConnector(PollConnector):
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(objective["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"entity_type": "release",
|
||||
"state": objective["state"],
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _is_updated_at_out_of_time_range(
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -137,7 +136,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
.execute_query()
|
||||
]
|
||||
else:
|
||||
sites = self.graph_client.sites.get().execute_query()
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
self.site_data = [
|
||||
SiteData(url=None, folder=None, sites=sites, driveitems=[])
|
||||
]
|
||||
|
||||
@@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import io
|
||||
import ipaddress
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -27,7 +29,7 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.html_utils import web_html_cleanup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.sitemap import list_pages_for_site
|
||||
@@ -84,6 +86,20 @@ def check_internet_connection(url: str) -> None:
|
||||
try:
|
||||
response = requests.get(url, timeout=3)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
# Extract status code from the response, defaulting to -1 if response is None
|
||||
status_code = e.response.status_code if e.response is not None else -1
|
||||
error_msg = {
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Not Found",
|
||||
500: "Internal Server Error",
|
||||
502: "Bad Gateway",
|
||||
503: "Service Unavailable",
|
||||
504: "Gateway Timeout",
|
||||
}.get(status_code, "HTTP Error")
|
||||
raise Exception(f"{error_msg} ({status_code}) for {url} - {e}")
|
||||
except requests.exceptions.SSLError as e:
|
||||
cause = (
|
||||
e.args[0].reason
|
||||
@@ -91,8 +107,8 @@ def check_internet_connection(url: str) -> None:
|
||||
else e.args
|
||||
)
|
||||
raise Exception(f"SSL error {str(cause)}")
|
||||
except (requests.RequestException, ValueError):
|
||||
raise Exception(f"Unable to reach {url} - check your internet connection")
|
||||
except (requests.RequestException, ValueError) as e:
|
||||
raise Exception(f"Unable to reach {url} - check your internet connection: {e}")
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
@@ -189,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None:
|
||||
try:
|
||||
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
class WebConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -271,7 +296,10 @@ class WebConnector(LoadConnector):
|
||||
if current_url.split(".")[-1] == "pdf":
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(current_url)
|
||||
page_text = pdf_to_text(file=io.BytesIO(response.content))
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@@ -279,13 +307,23 @@ class WebConnector(LoadConnector):
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
metadata={},
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
@@ -321,6 +359,11 @@ class WebConnector(LoadConnector):
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
)
|
||||
if last_modified
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
import requests
|
||||
from retry import retry
|
||||
from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects import Ticket # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -59,10 +60,15 @@ class ZendeskClientNotSetUpError(PermissionError):
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.zendesk_client: Zenpy | None = None
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.content_type = content_type
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def _set_content_tags(
|
||||
@@ -122,16 +128,86 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def _ticket_to_document(self, ticket: Ticket) -> Document:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
owner = None
|
||||
if ticket.requester and ticket.requester.name and ticket.requester.email:
|
||||
owner = [
|
||||
BasicExpertInfo(
|
||||
display_name=ticket.requester.name, email=ticket.requester.email
|
||||
)
|
||||
]
|
||||
update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
if ticket.status is not None:
|
||||
metadata["status"] = ticket.status
|
||||
if ticket.priority is not None:
|
||||
metadata["priority"] = ticket.priority
|
||||
if ticket.tags:
|
||||
metadata["tags"] = ticket.tags
|
||||
if ticket.type is not None:
|
||||
metadata["ticket_type"] = ticket.type
|
||||
|
||||
# Fetch comments for the ticket
|
||||
comments = self.zendesk_client.tickets.comments(ticket=ticket)
|
||||
|
||||
# Combine all comments into a single text
|
||||
comments_text = "\n\n".join(
|
||||
[
|
||||
f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}"
|
||||
f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}"
|
||||
for comment in comments
|
||||
if comment.body
|
||||
]
|
||||
)
|
||||
|
||||
# Combine ticket description and comments
|
||||
description = (
|
||||
ticket.description
|
||||
if hasattr(ticket, "description") and ticket.description
|
||||
else ""
|
||||
)
|
||||
full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}"
|
||||
|
||||
# Extract subdomain from ticket.url
|
||||
subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0]
|
||||
|
||||
# Build the html url for the ticket
|
||||
ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}"
|
||||
|
||||
return Document(
|
||||
id=f"zendesk_ticket_{ticket.id}",
|
||||
sections=[Section(link=ticket_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=owner,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
elif self.content_type == "tickets":
|
||||
yield from self._poll_tickets(start)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = (
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True)
|
||||
self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore
|
||||
if start is None
|
||||
else self.zendesk_client.help_center.articles.incremental(
|
||||
else self.zendesk_client.help_center.articles.incremental( # type: ignore
|
||||
start_time=int(start)
|
||||
)
|
||||
)
|
||||
@@ -155,9 +231,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.zendesk_client is None:
|
||||
raise ZendeskClientNotSetUpError()
|
||||
|
||||
ticket_generator = self.zendesk_client.tickets.incremental(start_time=start)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.status == "deleted":
|
||||
continue
|
||||
|
||||
doc_batch.append(self._ticket_to_document(ticket))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
except StopIteration:
|
||||
# No more tickets to process
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
@@ -18,11 +19,11 @@ class Message(BaseModel):
|
||||
sender_realm_str: str
|
||||
subject: str
|
||||
topic_links: Optional[List[Any]] = None
|
||||
last_edit_timestamp: Optional[int] = None
|
||||
edit_history: Any
|
||||
last_edit_timestamp: Optional[int]
|
||||
edit_history: Any = None
|
||||
reactions: List[Any]
|
||||
submessages: List[Any]
|
||||
flags: List[str] = []
|
||||
flags: List[str] = Field(default_factory=list)
|
||||
display_recipient: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
stream_id: int
|
||||
@@ -39,4 +40,4 @@ class GetMessagesResponse(BaseModel):
|
||||
found_newest: Optional[bool] = None
|
||||
history_limited: Optional[bool] = None
|
||||
anchor: Optional[str] = None
|
||||
messages: List[Message] = []
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
@@ -360,22 +359,6 @@ def build_quotes_block(
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def build_standard_answer_blocks(
|
||||
answer_message: str,
|
||||
) -> list[Block]:
|
||||
generate_button_block = ButtonElement(
|
||||
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
|
||||
text="Generate Full Answer",
|
||||
)
|
||||
answer_block = SectionBlock(text=answer_message)
|
||||
return [
|
||||
answer_block,
|
||||
ActionsBlock(
|
||||
elements=[generate_button_block],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
|
||||
@@ -6,7 +6,6 @@ FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
|
||||
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
|
||||
SLACK_CHANNEL_ID = "channel_id"
|
||||
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
|
||||
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -12,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
|
||||
@@ -88,6 +88,8 @@ def handle_generate_answer_button(
|
||||
message_ts = req.payload["message"]["ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
user_id = req.payload["user"]["id"]
|
||||
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
if not thread_ts:
|
||||
raise ValueError("Missing thread_ts in the payload")
|
||||
@@ -126,6 +128,7 @@ def handle_generate_answer_button(
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=user_id or None,
|
||||
email=email or None,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=False,
|
||||
@@ -134,7 +137,7 @@ def handle_generate_answer_button(
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
logger=cast(logging.Logger, logger),
|
||||
logger=logger,
|
||||
feedback_reminder_id=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
@@ -9,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
|
||||
handle_regular_answer,
|
||||
)
|
||||
@@ -17,7 +14,6 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
handle_standard_answers,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
@@ -25,7 +21,9 @@ from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.users import add_non_web_user_if_not_exists
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
logger_base = setup_logger()
|
||||
|
||||
@@ -53,12 +51,8 @@ def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
def schedule_feedback_reminder(
|
||||
details: SlackMessageInfo, include_followup: bool, client: WebClient
|
||||
) -> str | None:
|
||||
logger = cast(
|
||||
logging.Logger,
|
||||
ChannelIdAdapter(
|
||||
logger_base, extra={SLACK_CHANNEL_ID: details.channel_to_respond}
|
||||
),
|
||||
)
|
||||
logger = setup_logger(extra={SLACK_CHANNEL_ID: details.channel_to_respond})
|
||||
|
||||
if not DANSWER_BOT_FEEDBACK_REMINDER:
|
||||
logger.info("Scheduled feedback reminder disabled...")
|
||||
return None
|
||||
@@ -97,10 +91,7 @@ def schedule_feedback_reminder(
|
||||
def remove_scheduled_feedback_reminder(
|
||||
client: WebClient, channel: str | None, msg_id: str
|
||||
) -> None:
|
||||
logger = cast(
|
||||
logging.Logger,
|
||||
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
|
||||
)
|
||||
logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
try:
|
||||
client.chat_deleteScheduledMessage(
|
||||
@@ -129,10 +120,7 @@ def handle_message(
|
||||
"""
|
||||
channel = message_info.channel_to_respond
|
||||
|
||||
logger = cast(
|
||||
logging.Logger,
|
||||
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
|
||||
)
|
||||
logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
messages = message_info.thread_messages
|
||||
sender_id = message_info.sender
|
||||
@@ -222,6 +210,9 @@ def handle_message(
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if message_info.email:
|
||||
add_non_web_user_if_not_exists(message_info.email, db_session)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
used_standard_answer = handle_standard_answers(
|
||||
message_info=message_info,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -38,6 +37,8 @@ from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
@@ -49,8 +50,9 @@ from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
@@ -83,7 +85,7 @@ def handle_regular_answer(
|
||||
receiver_ids: list[str] | None,
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
logger: logging.Logger,
|
||||
logger: DanswerLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
@@ -98,6 +100,12 @@ def handle_regular_answer(
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_bot_config.persona if slack_bot_config else None
|
||||
@@ -127,7 +135,8 @@ def handle_regular_answer(
|
||||
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to:
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
)
|
||||
@@ -136,7 +145,6 @@ def handle_regular_answer(
|
||||
tries=num_retries,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
@@ -147,7 +155,12 @@ def handle_regular_answer(
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(db_session, new_message_request.persona_id),
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
@@ -180,7 +193,7 @@ def handle_regular_answer(
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=None,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
@@ -224,7 +237,8 @@ def handle_regular_answer(
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
saved_search_settings = get_search_settings()
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
@@ -237,7 +251,7 @@ def handle_regular_answer(
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=saved_search_settings.to_reranking_detail()
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
@@ -319,7 +333,7 @@ def handle_regular_answer(
|
||||
)
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.info(
|
||||
logger.notice(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
update_emote_react(
|
||||
@@ -357,7 +371,7 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
if not answer.answer and disable_docs_only_answer:
|
||||
logger.info(
|
||||
logger.notice(
|
||||
"Unable to find answer - not responding since the "
|
||||
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
|
||||
)
|
||||
@@ -406,7 +420,7 @@ def handle_regular_answer(
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
@@ -457,7 +471,9 @@ def handle_regular_answer(
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /danswer message
|
||||
# so we shouldn't send_team_member_message
|
||||
if receiver_ids and message_ts_to_respond_to is not None:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
|
||||
@@ -1,216 +1,56 @@
|
||||
import logging
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_sessions
|
||||
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from danswer.db.standard_answer import find_matching_standard_answers
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def oneoff_standard_answers(
|
||||
message: str,
|
||||
slack_bot_categories: list[str],
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
"""
|
||||
Respond to the user message if it matches any configured standard answers.
|
||||
|
||||
Returns a list of matching StandardAnswers if found, otherwise None.
|
||||
"""
|
||||
configured_standard_answers = {
|
||||
standard_answer
|
||||
for category in fetch_standard_answer_categories_by_names(
|
||||
slack_bot_categories, db_session=db_session
|
||||
)
|
||||
for standard_answer in category.standard_answers
|
||||
}
|
||||
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=message,
|
||||
id_in=[answer.id for answer in configured_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
server_standard_answers = [
|
||||
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
|
||||
]
|
||||
return server_standard_answers
|
||||
|
||||
|
||||
def handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
prompt: Prompt | None,
|
||||
logger: logging.Logger,
|
||||
logger: DanswerLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Returns whether one or more Standard Answer message blocks were
|
||||
emitted by the Slack bot"""
|
||||
versioned_handle_standard_answers = fetch_versioned_implementation(
|
||||
"danswer.danswerbot.slack.handlers.handle_standard_answers",
|
||||
"_handle_standard_answers",
|
||||
)
|
||||
return versioned_handle_standard_answers(
|
||||
message_info=message_info,
|
||||
receiver_ids=receiver_ids,
|
||||
slack_bot_config=slack_bot_config,
|
||||
prompt=prompt,
|
||||
logger=logger,
|
||||
client=client,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
prompt: Prompt | None,
|
||||
logger: DanswerLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Potentially respond to the user message depending on whether the user's message matches
|
||||
any of the configured standard answers and also whether those answers have already been
|
||||
provided in the current thread.
|
||||
Standard Answers are a paid Enterprise Edition feature. This is the fallback
|
||||
function handling the case where EE features are not enabled.
|
||||
|
||||
Returns True if standard answers are found to match the user's message and therefore,
|
||||
we still need to respond to the users.
|
||||
Always returns false i.e. since EE features are not enabled, we NEVER create any
|
||||
Slack message blocks.
|
||||
"""
|
||||
# if no channel config, then no standard answers are configured
|
||||
if not slack_bot_config:
|
||||
return False
|
||||
|
||||
slack_thread_id = message_info.thread_to_respond
|
||||
configured_standard_answer_categories = (
|
||||
slack_bot_config.standard_answer_categories if slack_bot_config else []
|
||||
)
|
||||
configured_standard_answers = set(
|
||||
[
|
||||
standard_answer
|
||||
for standard_answer_category in configured_standard_answer_categories
|
||||
for standard_answer in standard_answer_category.standard_answers
|
||||
]
|
||||
)
|
||||
query_msg = message_info.thread_messages[-1]
|
||||
|
||||
if slack_thread_id is None:
|
||||
used_standard_answer_ids = set([])
|
||||
else:
|
||||
chat_sessions = get_chat_sessions_by_slack_thread_id(
|
||||
slack_thread_id=slack_thread_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
chat_messages = get_chat_messages_by_sessions(
|
||||
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
)
|
||||
used_standard_answer_ids = set(
|
||||
[
|
||||
standard_answer.id
|
||||
for chat_message in chat_messages
|
||||
for standard_answer in chat_message.standard_answers
|
||||
]
|
||||
)
|
||||
|
||||
usable_standard_answers = configured_standard_answers.difference(
|
||||
used_standard_answer_ids
|
||||
)
|
||||
if usable_standard_answers:
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=query_msg.message,
|
||||
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
matching_standard_answers = []
|
||||
if matching_standard_answers:
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="",
|
||||
user_id=None,
|
||||
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
one_shot=True,
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=query_msg.message,
|
||||
token_count=10,
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
formatted_answers = []
|
||||
for standard_answer in matching_standard_answers:
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = (
|
||||
f'Since you mentioned _"{standard_answer.keyword}"_, '
|
||||
f"I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
)
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
_ = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=answer_message,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
restate_question_blocks = get_restate_blocks(
|
||||
msg=query_msg.message,
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
)
|
||||
|
||||
answer_blocks = build_standard_answer_blocks(
|
||||
answer_message=answer_message,
|
||||
)
|
||||
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
if receiver_ids and slack_thread_id:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to send standard answer message: {e}")
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,7 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
@@ -21,7 +22,6 @@ from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_I
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_doc_feedback_button
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_button
|
||||
@@ -39,7 +39,7 @@ from danswer.danswerbot.slack.handlers.handle_message import (
|
||||
from danswer.danswerbot.slack.handlers.handle_message import schedule_feedback_reminder
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import check_message_limit
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
|
||||
@@ -47,16 +47,19 @@ from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -84,18 +87,18 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
event = cast(dict[str, Any], req.payload.get("event", {}))
|
||||
msg = cast(str | None, event.get("text"))
|
||||
channel = cast(str | None, event.get("channel"))
|
||||
channel_specific_logger = ChannelIdAdapter(
|
||||
logger, extra={SLACK_CHANNEL_ID: channel}
|
||||
)
|
||||
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
# This should never happen, but we can't continue without a channel since
|
||||
# we can't send a response without it
|
||||
if not channel:
|
||||
channel_specific_logger.error("Found message without channel - skipping")
|
||||
channel_specific_logger.warning("Found message without channel - skipping")
|
||||
return False
|
||||
|
||||
if not msg:
|
||||
channel_specific_logger.error("Cannot respond to empty message - skipping")
|
||||
channel_specific_logger.warning(
|
||||
"Cannot respond to empty message - skipping"
|
||||
)
|
||||
return False
|
||||
|
||||
if (
|
||||
@@ -130,9 +133,19 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
|
||||
if event_type == "message":
|
||||
bot_tag_id = get_danswer_bot_app_id(client.web_client)
|
||||
|
||||
is_dm = event.get("channel_type") == "im"
|
||||
is_tagged = bot_tag_id and bot_tag_id in msg
|
||||
is_danswer_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
|
||||
|
||||
# DanswerBot should never respond to itself
|
||||
if is_danswer_bot_msg:
|
||||
logger.info("Ignoring message from DanswerBot")
|
||||
return False
|
||||
|
||||
# DMs with the bot don't pick up the @DanswerBot so we have to keep the
|
||||
# caught events_api
|
||||
if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
|
||||
if is_tagged and not is_dm:
|
||||
# Let the tag flow handle this case, don't reply twice
|
||||
return False
|
||||
|
||||
@@ -185,9 +198,8 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
if req.type == "slash_commands":
|
||||
# Verify that there's an associated channel
|
||||
channel = req.payload.get("channel_id")
|
||||
channel_specific_logger = ChannelIdAdapter(
|
||||
logger, extra={SLACK_CHANNEL_ID: channel}
|
||||
)
|
||||
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
|
||||
|
||||
if not channel:
|
||||
channel_specific_logger.error(
|
||||
"Received DanswerBot command without channel - skipping"
|
||||
@@ -201,6 +213,9 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
)
|
||||
return False
|
||||
|
||||
if not check_message_limit():
|
||||
return False
|
||||
|
||||
logger.debug(f"Handling Slack request with Payload: '{req.payload}'")
|
||||
return True
|
||||
|
||||
@@ -230,7 +245,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
|
||||
|
||||
def build_request_details(
|
||||
@@ -243,19 +258,26 @@ def build_request_details(
|
||||
tagged = event.get("type") == "app_mention"
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
sender = event.get("user") or None
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
|
||||
if DANSWER_BOT_REPHRASE_MESSAGE:
|
||||
logger.info(f"Rephrasing Slack message. Original message: {msg}")
|
||||
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
|
||||
try:
|
||||
msg = rephrase_slack_message(msg)
|
||||
logger.info(f"Rephrased message: {msg}")
|
||||
logger.notice(f"Rephrased message: {msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while trying to rephrase the Slack message: {e}")
|
||||
else:
|
||||
logger.notice(f"Received Slack message: {msg}")
|
||||
|
||||
if tagged:
|
||||
logger.info("User tagged DanswerBot")
|
||||
logger.debug("User tagged DanswerBot")
|
||||
|
||||
if thread_ts != message_ts and thread_ts is not None:
|
||||
thread_messages = read_slack_thread(
|
||||
@@ -271,7 +293,8 @@ def build_request_details(
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=event.get("user") or None,
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
@@ -281,6 +304,10 @@ def build_request_details(
|
||||
channel = req.payload["channel_id"]
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
expert_info = expert_info_from_slack_id(
|
||||
sender, client.web_client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
|
||||
@@ -290,6 +317,7 @@ def build_request_details(
|
||||
msg_to_respond=None,
|
||||
thread_to_respond=None,
|
||||
sender=sender,
|
||||
email=email,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
@@ -437,7 +465,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info("Listening for messages from Slack...")
|
||||
logger.notice("Listening for messages from Slack...")
|
||||
socket_client.connect()
|
||||
|
||||
|
||||
@@ -454,7 +482,9 @@ if __name__ == "__main__":
|
||||
slack_bot_tokens: SlackBotTokens | None = None
|
||||
socket_client: SocketModeClient | None = None
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
while True:
|
||||
@@ -463,18 +493,21 @@ if __name__ == "__main__":
|
||||
|
||||
if latest_slack_bot_tokens != slack_bot_tokens:
|
||||
if slack_bot_tokens is not None:
|
||||
logger.info("Slack Bot tokens have changed - reconnecting")
|
||||
logger.notice("Slack Bot tokens have changed - reconnecting")
|
||||
else:
|
||||
# This happens on the very first time the listener process comes up
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
|
||||
@@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel):
|
||||
msg_to_respond: str | None
|
||||
thread_to_respond: str | None
|
||||
sender: str | None
|
||||
email: str | None
|
||||
bypass_filters: bool # User has tagged @DanswerBot
|
||||
is_bot_msg: bool # User is using /DanswerBot
|
||||
is_bot_dm: bool # User is direct messaging to DanswerBot
|
||||
|
||||
@@ -3,7 +3,6 @@ import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -22,10 +21,15 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_VISIBILITY
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import (
|
||||
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD,
|
||||
)
|
||||
from danswer.configs.danswerbot_configs import (
|
||||
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
|
||||
)
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.users import get_user_by_email
|
||||
@@ -43,7 +47,41 @@ from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DANSWER_BOT_APP_ID: str | None = None
|
||||
_DANSWER_BOT_APP_ID: str | None = None
|
||||
_DANSWER_BOT_MESSAGE_COUNT: int = 0
|
||||
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
|
||||
|
||||
|
||||
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
|
||||
global _DANSWER_BOT_APP_ID
|
||||
if _DANSWER_BOT_APP_ID is None:
|
||||
_DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
|
||||
return _DANSWER_BOT_APP_ID
|
||||
|
||||
|
||||
def check_message_limit() -> bool:
|
||||
"""
|
||||
This isnt a perfect solution.
|
||||
High traffic at the end of one period and start of another could cause
|
||||
the limit to be exceeded.
|
||||
"""
|
||||
if DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD == 0:
|
||||
return True
|
||||
global _DANSWER_BOT_MESSAGE_COUNT
|
||||
global _DANSWER_BOT_COUNT_START_TIME
|
||||
time_since_start = time.time() - _DANSWER_BOT_COUNT_START_TIME
|
||||
if time_since_start > DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS:
|
||||
_DANSWER_BOT_MESSAGE_COUNT = 0
|
||||
_DANSWER_BOT_COUNT_START_TIME = time.time()
|
||||
if (_DANSWER_BOT_MESSAGE_COUNT + 1) > DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD:
|
||||
logger.error(
|
||||
f"DanswerBot has reached the message limit {DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD}"
|
||||
f" for the time period {DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS} seconds."
|
||||
" These limits are configurable in backend/danswer/configs/danswerbot_configs.py"
|
||||
)
|
||||
return False
|
||||
_DANSWER_BOT_MESSAGE_COUNT += 1
|
||||
return True
|
||||
|
||||
|
||||
def rephrase_slack_message(msg: str) -> str:
|
||||
@@ -98,32 +136,11 @@ def update_emote_react(
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
|
||||
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
|
||||
global DANSWER_BOT_APP_ID
|
||||
if DANSWER_BOT_APP_ID is None:
|
||||
DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
|
||||
return DANSWER_BOT_APP_ID
|
||||
|
||||
|
||||
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
|
||||
bot_tag_id = get_danswer_bot_app_id(web_client=client)
|
||||
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
|
||||
|
||||
|
||||
class ChannelIdAdapter(logging.LoggerAdapter):
|
||||
"""This is used to add the channel ID to all log messages
|
||||
emitted in this file"""
|
||||
|
||||
def process(
|
||||
self, msg: str, kwargs: MutableMapping[str, Any]
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
|
||||
if channel_id:
|
||||
return f"[Channel ID: {channel_id}] {msg}", kwargs
|
||||
else:
|
||||
return msg, kwargs
|
||||
|
||||
|
||||
def get_web_client() -> WebClient:
|
||||
slack_tokens = fetch_tokens()
|
||||
return WebClient(token=slack_tokens.bot_token)
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
|
||||
get_default_admin_user_emails_fn: Callable[
|
||||
[], list[str]
|
||||
] = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
|
||||
)
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
@@ -36,7 +35,7 @@ from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SearchDoc as ServerSearchDoc
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.tools.tool_runner import ToolCallMetadata
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -87,29 +86,57 @@ def get_chat_sessions_by_slack_thread_id(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_first_messages_for_chat_sessions(
|
||||
chat_session_ids: list[int], db_session: Session
|
||||
def get_valid_messages_from_query_sessions(
|
||||
chat_session_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> dict[int, str]:
|
||||
subquery = (
|
||||
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
|
||||
user_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER, # Select USER messages
|
||||
)
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
|
||||
subquery,
|
||||
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
|
||||
& (ChatMessage.id == subquery.c.min_id),
|
||||
assistant_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id,
|
||||
func.min(ChatMessage.id).label("assistant_msg_id"),
|
||||
)
|
||||
.where(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(ChatMessage.chat_session_id, ChatMessage.message)
|
||||
.join(
|
||||
user_message_subquery,
|
||||
ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
assistant_message_subquery,
|
||||
ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
ChatMessage__SearchDoc,
|
||||
ChatMessage__SearchDoc.chat_message_id
|
||||
== assistant_message_subquery.c.assistant_msg_id,
|
||||
)
|
||||
.where(ChatMessage.id == user_message_subquery.c.user_msg_id)
|
||||
)
|
||||
|
||||
first_messages = db_session.execute(query).all()
|
||||
return dict([(row.chat_session_id, row.message) for row in first_messages])
|
||||
logger.info(f"Retrieved {len(first_messages)} first messages with documents")
|
||||
|
||||
return {row.chat_session_id: row.message for row in first_messages}
|
||||
|
||||
|
||||
def get_chat_sessions_by_user(
|
||||
@@ -117,6 +144,7 @@ def get_chat_sessions_by_user(
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
@@ -130,6 +158,9 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
@@ -255,6 +286,13 @@ def delete_chat_session(
|
||||
db_session: Session,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Cannot delete an already deleted chat session")
|
||||
|
||||
if hard_delete:
|
||||
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
|
||||
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
|
||||
@@ -445,7 +483,7 @@ def create_new_chat_message(
|
||||
tool_call: ToolCall | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
alternate_model: str | None = None,
|
||||
overridden_model: str | None = None,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
@@ -462,10 +500,10 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_call = tool_call
|
||||
existing_message.tool_call = tool_call if tool_call else None
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.alternate_model = alternate_model
|
||||
existing_message.overridden_model = overridden_model
|
||||
|
||||
new_chat_message = existing_message
|
||||
else:
|
||||
@@ -481,10 +519,10 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_call=tool_call,
|
||||
tool_call=tool_call if tool_call else None,
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
alternate_model=alternate_model,
|
||||
overridden_model=overridden_model,
|
||||
)
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
@@ -498,6 +536,7 @@ def create_new_chat_message(
|
||||
parent_message.latest_child_message = new_chat_message.id
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return new_chat_message
|
||||
|
||||
|
||||
@@ -714,7 +753,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_call=ToolCallMetadata(
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
@@ -722,7 +761,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
alternate_model=chat_message.alternate_model,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -75,8 +75,8 @@ def fetch_ingestion_connector_by_name(
|
||||
|
||||
|
||||
def create_connector(
|
||||
connector_data: ConnectorBase,
|
||||
db_session: Session,
|
||||
connector_data: ConnectorBase,
|
||||
) -> ObjectCreationIdResponse:
|
||||
if connector_by_name_source_exists(
|
||||
connector_data.name, connector_data.source, db_session
|
||||
@@ -132,8 +132,8 @@ def update_connector(
|
||||
|
||||
|
||||
def delete_connector(
|
||||
connector_id: int,
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
) -> StatusResponse[int]:
|
||||
"""Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
|
||||
Be VERY careful using this, as it could lead to a bad state if not used correctly.
|
||||
|
||||
@@ -3,7 +3,10 @@ from datetime import datetime
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -11,35 +14,127 @@ from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
|
||||
"""
|
||||
Here we select cc_pairs by relation:
|
||||
User -> User__UserGroup -> UserGroup__ConnectorCredentialPair ->
|
||||
ConnectorCredentialPair
|
||||
"""
|
||||
stmt = stmt.outerjoin(UG__CCpair).outerjoin(
|
||||
User__UG,
|
||||
User__UG.user_group_id == UG__CCpair.user_group_id,
|
||||
)
|
||||
|
||||
"""
|
||||
Filter cc_pairs by:
|
||||
- if the user is in the user_group that owns the cc_pair
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out cc_pairs that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all cc_pairs in the groups the user is a curator
|
||||
for (as well as public cc_pairs)
|
||||
"""
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(UG__CCpair.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.where(~UG__CCpair.user_group_id.in_(user_groups))
|
||||
.correlate(ConnectorCredentialPair)
|
||||
)
|
||||
else:
|
||||
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
db_session: Session, include_disabled: bool = True
|
||||
db_session: Session,
|
||||
include_disabled: bool = True,
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
if not include_disabled:
|
||||
stmt = stmt.where(
|
||||
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
|
||||
) # noqa
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
|
||||
def add_deletion_failure_message(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
failure_message: str,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
return
|
||||
cc_pair.deletion_failure_message = failure_message
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_cc_pair_groups_for_ids(
|
||||
db_session: Session,
|
||||
cc_pair_ids: list[int],
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> list[UserGroup__ConnectorCredentialPair]:
|
||||
stmt = select(UserGroup__ConnectorCredentialPair).distinct()
|
||||
stmt = stmt.outerjoin(
|
||||
ConnectorCredentialPair,
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(UserGroup__ConnectorCredentialPair.cc_pair_id.in_(cc_pair_ids))
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_connector_credential_pair(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
|
||||
result = db_session.execute(stmt)
|
||||
@@ -49,8 +144,11 @@ def get_connector_credential_pair(
|
||||
def get_connector_credential_source_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> DocumentSource | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
result = db_session.execute(stmt)
|
||||
cc_pair = result.scalar_one_or_none()
|
||||
@@ -60,8 +158,11 @@ def get_connector_credential_source_from_id(
|
||||
def get_connector_credential_pair_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
@@ -70,12 +171,13 @@ def get_connector_credential_pair_from_id(
|
||||
def get_last_successful_attempt_time(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
embedding_model: EmbeddingModel,
|
||||
earliest_index: float,
|
||||
search_settings: SearchSettings,
|
||||
db_session: Session,
|
||||
) -> float:
|
||||
"""Gets the timestamp of the last successful index run stored in
|
||||
the CC Pair row in the database"""
|
||||
if embedding_model.status == IndexModelStatus.PRESENT:
|
||||
if search_settings.status == IndexModelStatus.PRESENT:
|
||||
connector_credential_pair = get_connector_credential_pair(
|
||||
connector_id, credential_id, db_session
|
||||
)
|
||||
@@ -83,7 +185,7 @@ def get_last_successful_attempt_time(
|
||||
connector_credential_pair is None
|
||||
or connector_credential_pair.last_successful_index_time is None
|
||||
):
|
||||
return 0.0
|
||||
return earliest_index
|
||||
|
||||
return connector_credential_pair.last_successful_index_time.timestamp()
|
||||
|
||||
@@ -97,17 +199,15 @@ def get_last_successful_attempt_time(
|
||||
.filter(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
IndexAttempt.status == IndexingStatus.SUCCESS,
|
||||
)
|
||||
.order_by(IndexAttempt.time_started.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not attempt or not attempt.time_started:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector and connector.indexing_start:
|
||||
return connector.indexing_start.timestamp()
|
||||
return 0.0
|
||||
return earliest_index
|
||||
|
||||
return attempt.time_started.timestamp()
|
||||
|
||||
@@ -217,14 +317,28 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _relate_groups_to_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
user_group_ids: list[int],
|
||||
) -> None:
|
||||
for group_id in user_group_ids:
|
||||
db_session.add(
|
||||
UserGroup__ConnectorCredentialPair(
|
||||
user_group_id=group_id, cc_pair_id=cc_pair_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def add_credential_to_connector(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
is_public: bool,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
groups: list[int] | None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
@@ -232,9 +346,13 @@ def add_credential_to_connector(
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if credential is None:
|
||||
error_msg = (
|
||||
f"Credential {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credential does not exist or does not belong to user",
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
existing_association = (
|
||||
@@ -248,7 +366,7 @@ def add_credential_to_connector(
|
||||
if existing_association is not None:
|
||||
return StatusResponse(
|
||||
success=False,
|
||||
message=f"Connector already has Credential {credential_id}",
|
||||
message=f"Connector {connector_id} already has Credential {credential_id}",
|
||||
data=connector_id,
|
||||
)
|
||||
|
||||
@@ -260,12 +378,21 @@ def add_credential_to_connector(
|
||||
is_public=is_public,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
|
||||
if groups:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message=f"New Credential {credential_id} added to Connector",
|
||||
data=connector_id,
|
||||
message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
|
||||
data=association.id,
|
||||
)
|
||||
|
||||
|
||||
@@ -287,13 +414,12 @@ def remove_credential_from_connector(
|
||||
detail="Credential does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
association = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
.one_or_none()
|
||||
association = get_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
|
||||
if association is not None:
|
||||
@@ -334,11 +460,11 @@ def resync_cc_pair(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT,
|
||||
SearchSettings.status == IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
@@ -17,8 +18,10 @@ from danswer.connectors.google_drive.constants import (
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Credential__UserGroup
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -26,42 +29,122 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# The credentials for these sources are not real so
|
||||
# permissions are not enforced for them
|
||||
CREDENTIAL_PERMISSIONS_TO_IGNORE = {
|
||||
DocumentSource.FILE,
|
||||
DocumentSource.WEB,
|
||||
DocumentSource.NOT_APPLICABLE,
|
||||
DocumentSource.GOOGLE_SITES,
|
||||
DocumentSource.WIKIPEDIA,
|
||||
DocumentSource.MEDIAWIKI,
|
||||
}
|
||||
|
||||
def _attach_user_filters(
|
||||
stmt: Select[tuple[Credential]],
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select,
|
||||
user: User | None,
|
||||
assume_admin: bool = False, # Used with API key
|
||||
get_editable: bool = True,
|
||||
) -> Select:
|
||||
"""Attaches filters to the statement to ensure that the user can only
|
||||
access the appropriate credentials"""
|
||||
if user:
|
||||
if user.role == UserRole.ADMIN:
|
||||
if not user:
|
||||
if assume_admin:
|
||||
# apply admin filters minus the user_id check
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
Credential.user_id.is_(None),
|
||||
Credential.admin_public == True, # noqa: E712
|
||||
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(Credential.user_id == user.id)
|
||||
elif assume_admin:
|
||||
stmt = stmt.where(
|
||||
return stmt
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
# Admins can access all credentials that are public or owned by them
|
||||
# or are not associated with any user
|
||||
return stmt.where(
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
Credential.user_id.is_(None),
|
||||
Credential.admin_public == True, # noqa: E712
|
||||
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
|
||||
)
|
||||
)
|
||||
if user.role == UserRole.BASIC:
|
||||
# Basic users can only access credentials that are owned by them
|
||||
return stmt.where(Credential.user_id == user.id)
|
||||
|
||||
return stmt
|
||||
"""
|
||||
THIS PART IS FOR CURATORS AND GLOBAL CURATORS
|
||||
Here we select cc_pairs by relation:
|
||||
User -> User__UserGroup -> Credential__UserGroup -> Credential
|
||||
"""
|
||||
stmt = stmt.outerjoin(Credential__UserGroup).outerjoin(
|
||||
User__UserGroup,
|
||||
User__UserGroup.user_group_id == Credential__UserGroup.user_group_id,
|
||||
)
|
||||
"""
|
||||
Filter Credentials by:
|
||||
- if the user is in the user_group that owns the Credential
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Credentials that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all Credentials in the groups the user is a curator
|
||||
for (as well as public Credentials)
|
||||
- if we are not editing, we return all Credentials directly connected to the user
|
||||
"""
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(Credential__UserGroup.credential_id == Credential.id)
|
||||
.where(~Credential__UserGroup.user_group_id.in_(user_groups))
|
||||
.correlate(Credential)
|
||||
)
|
||||
else:
|
||||
where_clause |= Credential.curator_public == True # noqa: E712
|
||||
where_clause |= Credential.user_id == user.id # noqa: E712
|
||||
|
||||
where_clause |= Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE)
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def _relate_credential_to_user_groups__no_commit(
|
||||
db_session: Session,
|
||||
credential_id: int,
|
||||
user_group_ids: list[int],
|
||||
) -> None:
|
||||
credential_user_groups = []
|
||||
for group_id in user_group_ids:
|
||||
credential_user_groups.append(
|
||||
Credential__UserGroup(
|
||||
credential_id=credential_id,
|
||||
user_group_id=group_id,
|
||||
)
|
||||
)
|
||||
db_session.add_all(credential_user_groups)
|
||||
|
||||
|
||||
def fetch_credentials(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> list[Credential]:
|
||||
stmt = select(Credential)
|
||||
stmt = _attach_user_filters(stmt, user)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
@@ -72,8 +155,9 @@ def fetch_credential_by_id(
|
||||
db_session: Session,
|
||||
assume_admin: bool = False,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).where(Credential.id == credential_id)
|
||||
stmt = _attach_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
@@ -83,9 +167,10 @@ def fetch_credentials_by_source(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
document_source: DocumentSource | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> list[Credential]:
|
||||
base_query = select(Credential).where(Credential.source == document_source)
|
||||
base_query = _attach_user_filters(base_query, user)
|
||||
base_query = _add_user_filters(base_query, user, get_editable=get_editable)
|
||||
credentials = db_session.execute(base_query).scalars().all()
|
||||
return list(credentials)
|
||||
|
||||
@@ -153,19 +238,38 @@ def create_credential(
|
||||
admin_public=credential_data.admin_public,
|
||||
source=credential_data.source,
|
||||
name=credential_data.name,
|
||||
curator_public=credential_data.curator_public,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush() # This ensures the credential gets an ID
|
||||
|
||||
_relate_credential_to_user_groups__no_commit(
|
||||
db_session=db_session,
|
||||
credential_id=credential.id,
|
||||
user_group_ids=credential_data.groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return credential
|
||||
|
||||
|
||||
def _cleanup_credential__user_group_relationships__no_commit(
|
||||
db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(Credential__UserGroup).filter(
|
||||
Credential__UserGroup.credential_id == credential_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def alter_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialDataUpdateRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> Credential | None:
|
||||
# TODO: add user group relationship update
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
if credential is None:
|
||||
@@ -271,10 +375,11 @@ def delete_credential(
|
||||
)
|
||||
|
||||
if force:
|
||||
logger.info(f"Force deleting credential {credential_id}")
|
||||
logger.warning(f"Force deleting credential {credential_id}")
|
||||
else:
|
||||
logger.info(f"Deleting credential {credential_id}")
|
||||
logger.notice(f"Deleting credential {credential_id}")
|
||||
|
||||
_cleanup_credential__user_group_relationships__no_commit(db_session, credential_id)
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
|
||||
|
||||
def check_deletion_attempt_is_allowed(
|
||||
@@ -24,20 +23,17 @@ def check_deletion_attempt_is_allowed(
|
||||
f"'{connector_credential_pair.credential_id}' is not deletable."
|
||||
)
|
||||
|
||||
if (
|
||||
connector_credential_pair.status != ConnectorCredentialPairStatus.PAUSED
|
||||
and connector_credential_pair.status != ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
if connector_credential_pair.status.is_active():
|
||||
return base_error_msg + " Connector must be paused."
|
||||
|
||||
connector_id = connector_credential_pair.connector_id
|
||||
credential_id = connector_credential_pair.credential_id
|
||||
current_embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
last_indexing = get_last_attempt(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
embedding_model_id=current_embedding_model.id,
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -10,6 +11,7 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.engine.util import TransactionalContext
|
||||
@@ -38,6 +40,68 @@ def check_docs_exist(db_session: Session) -> bool:
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def count_documents_by_needs_sync(session: Session) -> int:
|
||||
"""Get the count of all documents where:
|
||||
1. last_modified is newer than last_synced
|
||||
2. last_synced is null (meaning we've never synced)
|
||||
|
||||
This function executes the query and returns the count of
|
||||
documents matching the criteria."""
|
||||
|
||||
count = (
|
||||
session.query(func.count())
|
||||
.select_from(DbDocument)
|
||||
.filter(
|
||||
or_(
|
||||
DbDocument.last_modified > DbDocument.last_synced,
|
||||
DbDocument.last_synced.is_(None),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
connector_id: int, credential_id: int
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.where(
|
||||
DbDocument.id.in_(initial_doc_ids_stmt),
|
||||
or_(
|
||||
DbDocument.last_modified
|
||||
> DbDocument.last_synced, # last_modified is newer than last_synced
|
||||
DbDocument.last_synced.is_(None), # never synced
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair(
|
||||
connector_id: int, credential_id: int | None = None
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
|
||||
return stmt
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -108,7 +172,29 @@ def get_document_cnts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_acccess_info_for_documents(
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> tuple[str, list[UUID | None], bool] | None:
|
||||
"""Gets access info for a single document by calling the get_access_info_for_documents function
|
||||
and passing a list with a single document ID.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use.
|
||||
document_id (str): The document ID to fetch access info for.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
|
||||
and a boolean indicating if the document is globally public, or None if no results are found.
|
||||
"""
|
||||
results = get_access_info_for_documents(db_session, [document_id])
|
||||
if not results:
|
||||
return None
|
||||
|
||||
return results[0]
|
||||
|
||||
|
||||
def get_access_info_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, list[UUID | None], bool]]:
|
||||
@@ -173,6 +259,7 @@ def upsert_documents(
|
||||
semantic_id=doc.semantic_identifier,
|
||||
link=doc.first_link,
|
||||
doc_updated_at=None, # this is intentional
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
)
|
||||
@@ -214,7 +301,7 @@ def upsert_document_by_connector_credential_pair(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_docs_updated_at(
|
||||
def update_docs_updated_at__no_commit(
|
||||
ids_to_new_updated_at: dict[str, datetime],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -226,6 +313,28 @@ def update_docs_updated_at(
|
||||
for document in documents_to_update:
|
||||
document.doc_updated_at = ids_to_new_updated_at[document.id]
|
||||
|
||||
|
||||
def update_docs_last_modified__no_commit(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
documents_to_update = (
|
||||
db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
for doc in documents_to_update:
|
||||
doc.last_modified = now
|
||||
|
||||
|
||||
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc = db_session.scalar(stmt)
|
||||
if doc is None:
|
||||
raise ValueError(f"No document with ID: {document_id}")
|
||||
|
||||
# update last_synced
|
||||
doc.last_synced = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -317,7 +426,7 @@ def prepare_to_modify_documents(
|
||||
called ahead of any modification to Vespa. Locks should be released by the
|
||||
caller as soon as updates are complete by finishing the transaction.
|
||||
|
||||
NOTE: only one commit is allowed within the context manager returned by this funtion.
|
||||
NOTE: only one commit is allowed within the context manager returned by this function.
|
||||
Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
|
||||
NOTE: this function will commit any existing transaction.
|
||||
"""
|
||||
@@ -335,7 +444,9 @@ def prepare_to_modify_documents(
|
||||
yield transaction
|
||||
break
|
||||
except OperationalError as e:
|
||||
logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
|
||||
logger.warning(
|
||||
f"Failed to acquire locks for documents, retrying. Error: {e}"
|
||||
)
|
||||
|
||||
time.sleep(retry_delay)
|
||||
|
||||
@@ -377,3 +488,12 @@ def get_documents_by_cc_pair(
|
||||
.filter(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DbDocument | None:
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
return doc
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user