mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-21 09:45:46 +00:00
Compare commits
252 Commits
v0.3.87
...
eval/split
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4293543a6a | ||
|
|
e95bfa0e0b | ||
|
|
4848b5f1de | ||
|
|
7ba5c434fa | ||
|
|
59bf5ba848 | ||
|
|
f66c33380c | ||
|
|
115650ce9f | ||
|
|
7aa3602fca | ||
|
|
864c552a17 | ||
|
|
07b2ed3d8f | ||
|
|
38290057f2 | ||
|
|
2344edf158 | ||
|
|
86d1804eb0 | ||
|
|
1ebae50d0c | ||
|
|
a9fbaa396c | ||
|
|
27d5f69427 | ||
|
|
5d98421ae8 | ||
|
|
6b561b8ca9 | ||
|
|
2dc7e64dd7 | ||
|
|
5230f7e22f | ||
|
|
a595d43ae3 | ||
|
|
ee561f42ff | ||
|
|
f00b3d76b3 | ||
|
|
e4984153c0 | ||
|
|
87fadb07ea | ||
|
|
2b07c102f9 | ||
|
|
e93de602c3 | ||
|
|
1c77395503 | ||
|
|
cdf6089b3e | ||
|
|
d01f46af2b | ||
|
|
b83f435bb0 | ||
|
|
25b3dacaba | ||
|
|
a1e638a73d | ||
|
|
bd1e0c5969 | ||
|
|
4d295ab97d | ||
|
|
6fe3eeaa48 | ||
|
|
078d5defbb | ||
|
|
0d52e99bd4 | ||
|
|
1b864a00e4 | ||
|
|
dae4f6a0bd | ||
|
|
f63d0ca3ad | ||
|
|
da31da33e7 | ||
|
|
56b175f597 | ||
|
|
1b311d092e | ||
|
|
6ee1292757 | ||
|
|
017af052be | ||
|
|
e7f81d1688 | ||
|
|
b6bd818e60 | ||
|
|
36da2e4b27 | ||
|
|
c7af6a4601 | ||
|
|
e90c66c1b6 | ||
|
|
8c312482c1 | ||
|
|
e50820e65e | ||
|
|
991ee79e47 | ||
|
|
3e645a510e | ||
|
|
08c6e821e7 | ||
|
|
47a550221f | ||
|
|
511f619212 | ||
|
|
6c51f001dc | ||
|
|
09a11b5e1a | ||
|
|
aa0f7abdac | ||
|
|
7c8f8dba17 | ||
|
|
39982e5fdc | ||
|
|
5e0de111f9 | ||
|
|
727d80f168 | ||
|
|
146f85936b | ||
|
|
e06f8a0a4b | ||
|
|
f0888f2f61 | ||
|
|
d35d7ee833 | ||
|
|
c5bb3fde94 | ||
|
|
79190030a5 | ||
|
|
8e8f262ed3 | ||
|
|
ac14369716 | ||
|
|
de4d8e9a65 | ||
|
|
0b384c5b34 | ||
|
|
fa049f4f98 | ||
|
|
72d6a0ef71 | ||
|
|
ae4e643266 | ||
|
|
a7da07afc0 | ||
|
|
7f1bb67e52 | ||
|
|
982b1b0c49 | ||
|
|
2db128fb36 | ||
|
|
3ebac6256f | ||
|
|
1a3ec59610 | ||
|
|
581cb827bb | ||
|
|
393b3c9343 | ||
|
|
2035e9f39c | ||
|
|
52c3a5e9d2 | ||
|
|
3e45a41617 | ||
|
|
415960564d | ||
|
|
ed550986a6 | ||
|
|
60dd77393d | ||
|
|
3fe5313b02 | ||
|
|
bd0925611a | ||
|
|
de6d040349 | ||
|
|
38da3128d8 | ||
|
|
e47da0d688 | ||
|
|
2c0e0c5f11 | ||
|
|
29d57f6354 | ||
|
|
369e607631 | ||
|
|
f03f97307f | ||
|
|
145cdb69b7 | ||
|
|
9310a8edc2 | ||
|
|
2140f80891 | ||
|
|
52dab23295 | ||
|
|
91c9b2eb42 | ||
|
|
5764cdd469 | ||
|
|
8fea6d7f64 | ||
|
|
5324b15397 | ||
|
|
8be42a5f98 | ||
|
|
062dc98719 | ||
|
|
43557f738b | ||
|
|
b5aa7370a2 | ||
|
|
4ba6e45128 | ||
|
|
d6e5a98a22 | ||
|
|
20c4cdbdda | ||
|
|
0d814939ee | ||
|
|
7d2b0ffcc5 | ||
|
|
8c6cd661f5 | ||
|
|
5d552705aa | ||
|
|
1ee8ee9e8b | ||
|
|
f0b2b57d81 | ||
|
|
5c12a3e872 | ||
|
|
3af81ca96b | ||
|
|
f55e5415bb | ||
|
|
3d434c2c9e | ||
|
|
90ec156791 | ||
|
|
8ba48e24a6 | ||
|
|
e34bcbbd06 | ||
|
|
db319168f8 | ||
|
|
010ce5395f | ||
|
|
98a58337a7 | ||
|
|
733d4e666b | ||
|
|
2937fe9e7d | ||
|
|
457527ac86 | ||
|
|
7cc51376f2 | ||
|
|
7278d45552 | ||
|
|
1c343bbee7 | ||
|
|
bdcfb39724 | ||
|
|
694d20ea8f | ||
|
|
45402d0755 | ||
|
|
69740ba3d5 | ||
|
|
6162283beb | ||
|
|
44284f7912 | ||
|
|
775ca5787b | ||
|
|
c6e49a3034 | ||
|
|
9c8cfd9175 | ||
|
|
fc3ed76d12 | ||
|
|
a2597d5f21 | ||
|
|
af588461d2 | ||
|
|
460e61b3a7 | ||
|
|
c631ac0c3a | ||
|
|
10be91a8cc | ||
|
|
eadad34a77 | ||
|
|
b19d88a151 | ||
|
|
e33b469915 | ||
|
|
719fc06604 | ||
|
|
d7a704c0d9 | ||
|
|
7a408749cf | ||
|
|
d9acd03a85 | ||
|
|
af94c092e7 | ||
|
|
f55a4ef9bd | ||
|
|
6c6e33e001 | ||
|
|
336c046e5d | ||
|
|
9a9b89f073 | ||
|
|
89fac98534 | ||
|
|
65b65518de | ||
|
|
0c827d1e6c | ||
|
|
1984f2c1ca | ||
|
|
50f006557f | ||
|
|
c00bd44bcc | ||
|
|
680aca68e5 | ||
|
|
22a2f86fb9 | ||
|
|
c055dc1535 | ||
|
|
81e9880d9d | ||
|
|
3466f6d3a4 | ||
|
|
91cf45165f | ||
|
|
ee2a5bbf49 | ||
|
|
153007c57c | ||
|
|
fa8cc10063 | ||
|
|
2c3ba5f021 | ||
|
|
e3ef620094 | ||
|
|
40369e0538 | ||
|
|
d6c5c65b51 | ||
|
|
7b16cb9562 | ||
|
|
ef4f06a375 | ||
|
|
17cc262f5d | ||
|
|
680482bd06 | ||
|
|
64874d2737 | ||
|
|
00ade322f1 | ||
|
|
eab5d054d5 | ||
|
|
a09d60d7d0 | ||
|
|
f17dc52b37 | ||
|
|
c1862e961b | ||
|
|
6b46a71cb5 | ||
|
|
9ae3a4af7f | ||
|
|
328b96c9ff | ||
|
|
bac34a47b2 | ||
|
|
15934ee268 | ||
|
|
fe975c3357 | ||
|
|
8bf483904d | ||
|
|
db338bfddf | ||
|
|
ae02a5199a | ||
|
|
4b44073d9a | ||
|
|
ce36530c79 | ||
|
|
39d69838c5 | ||
|
|
e11f0f6202 | ||
|
|
ce870ff577 | ||
|
|
a52711967f | ||
|
|
67a4eb6f6f | ||
|
|
9599388db8 | ||
|
|
f82ae158ea | ||
|
|
670de6c00d | ||
|
|
56c52bddff | ||
|
|
3984350ff9 | ||
|
|
f799d9aa11 | ||
|
|
529f2c8c2d | ||
|
|
b4683dc841 | ||
|
|
db8ce61ff4 | ||
|
|
d016e8335e | ||
|
|
0c295d1de5 | ||
|
|
e9f273d99a | ||
|
|
428f5edd21 | ||
|
|
50170cc97e | ||
|
|
7503f8f37b | ||
|
|
92de6acc6f | ||
|
|
65d5808ea7 | ||
|
|
061dab7f37 | ||
|
|
e65d9e155d | ||
|
|
50f799edf4 | ||
|
|
c1d8f6cb66 | ||
|
|
6c71bc05ea | ||
|
|
123ec4342a | ||
|
|
7253316b9e | ||
|
|
4ae924662c | ||
|
|
094eea2742 | ||
|
|
8178d536b4 | ||
|
|
5cafc96cae | ||
|
|
3e39a921b0 | ||
|
|
98b2507045 | ||
|
|
3dfe17a54d | ||
|
|
b4675082b1 | ||
|
|
287a706e89 | ||
|
|
ba58208a85 | ||
|
|
694e9e8679 | ||
|
|
9e30ec1f1f | ||
|
|
1b56c75527 | ||
|
|
b07fdbf1d1 | ||
|
|
54c2547d89 | ||
|
|
58b5e25c97 | ||
|
|
4e15ba78d5 | ||
|
|
c798ade127 |
33
.github/workflows/docker-build-backend-container-on-merge-group.yml
vendored
Normal file
33
.github/workflows/docker-build-backend-container-on-merge-group.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Build Backend Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
jobs:
|
||||
build:
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Backend Image Docker Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: false
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
@@ -5,9 +5,14 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
# TODO: make this a matrix build like the web containers
|
||||
runs-on:
|
||||
group: amd64-image-builders
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -30,8 +35,8 @@ jobs:
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-backend:${{ github.ref_name }}
|
||||
danswer/danswer-backend:latest
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.REGISTRY_IMAGE }}:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
@@ -39,6 +44,6 @@ jobs:
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -10,7 +10,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -34,8 +34,8 @@ jobs:
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=danswer/danswer-web-server:${{ github.ref_name }}
|
||||
type=raw,value=danswer/danswer-web-server:latest
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
53
.github/workflows/docker-build-web-container-on-merge-group.yml
vendored
Normal file
53
.github/workflows/docker-build-web-container-on-merge-group.yml
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
name: Build Web Image on Merge Group
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: false
|
||||
build-args: |
|
||||
DANSWER_VERSION=v0.0.1
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,3 +5,5 @@
|
||||
.idea
|
||||
/deployment/data/nginx/app.conf
|
||||
.vscode/launch.json
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
23
.vscode/env_template.txt
vendored
23
.vscode/env_template.txt
vendored
@@ -4,17 +4,20 @@
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation, quite token heavy so we disable it for dev generally
|
||||
DISABLE_LLM_CHUNK_FILTER=True
|
||||
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
LOG_ALL_MODEL_INTERACTIONS=True
|
||||
LOG_DANSWER_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
DISABLE_LLM_CHUNK_FILTER=True
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
@@ -22,10 +25,6 @@ OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
|
||||
# Toggles on/off the EE Features
|
||||
NEXT_PUBLIC_ENABLE_PAID_EE_FEATURES=False
|
||||
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
|
||||
@@ -41,3 +40,13 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo
|
||||
# Python stuff
|
||||
PYTHONPATH=./backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
28
.vscode/launch.template.jsonc
vendored
28
.vscode/launch.template.jsonc
vendored
@@ -17,6 +17,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
@@ -28,6 +29,7 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
@@ -45,8 +47,9 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_ALL_MODEL_INTERACTIONS": "True",
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
@@ -63,6 +66,7 @@
|
||||
"request": "launch",
|
||||
"program": "danswer/background/update.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"ENABLE_MINI_CHUNK": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
@@ -77,7 +81,9 @@
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
@@ -100,6 +106,24 @@
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,10 @@ For convenience here's a command for it:
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
--> Note that this virtual environment MUST NOT be set up WITHIN the danswer
|
||||
directory
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
|
||||
8
LICENSE
8
LICENSE
@@ -1,6 +1,10 @@
|
||||
MIT License
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Copyright (c) 2023 Yuhong Sun, Chris Weaver
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
* All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
* All third party components incorporated into the Danswer Software are licensed under the original license provided by the owner of the applicable component.
|
||||
* Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
22
README.md
22
README.md
@@ -11,7 +11,7 @@
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -105,5 +105,25 @@ Efficiently pulls the latest changes from:
|
||||
* Websites
|
||||
* And more ...
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Danswer:
|
||||
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
* Usage analytics and query history accessible to admins
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
|
||||
To try the Danswer Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -5,7 +5,7 @@ site_crawls/
|
||||
.ipynb_checkpoints/
|
||||
api_keys.py
|
||||
*ipynb
|
||||
.env
|
||||
.env*
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
LABEL com.danswer.maintainer="founders@danswer.ai"
|
||||
LABEL com.danswer.description="This image is for the backend of Danswer. It is MIT Licensed and \
|
||||
free for all to use. You can find it at https://hub.docker.com/r/danswer/danswer-backend. For \
|
||||
more details, visit https://github.com/danswer-ai/danswer."
|
||||
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
|
||||
contains code for both the Community and Enterprise editions of Danswer. If you do not \
|
||||
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
@@ -17,18 +19,32 @@ RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
# zip for Vespa step futher down
|
||||
# ca-certificates for HTTPS
|
||||
RUN apt-get update && \
|
||||
apt-get install -y cmake curl zip ca-certificates libgnutls30=3.7.9-2+deb12u2 \
|
||||
libblkid1=2.38.1-5+deb12u1 libmount1=2.38.1-5+deb12u1 libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 && \
|
||||
apt-get install -y \
|
||||
cmake \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && playwright install-deps chromium && \
|
||||
playwright install chromium && \
|
||||
playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
@@ -36,11 +52,20 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake \
|
||||
libldap-2.5-0 libldap-2.5-0 && \
|
||||
RUN apt-get update && \
|
||||
apt-get remove -y --allow-remove-essential \
|
||||
perl-base \
|
||||
xserver-common \
|
||||
xvfb \
|
||||
cmake \
|
||||
libldap-2.5-0 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc && \
|
||||
apt-get install -y libxmlsec1-openssl && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
|
||||
@@ -53,12 +78,24 @@ nltk.download('punkt', quiet=True);"
|
||||
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
|
||||
# Enterprise Version Files
|
||||
COPY ./ee /app/ee
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
# Set up application files
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Escape hatch
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# Default command which does nothing
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Add thread specific model selection
|
||||
|
||||
Revision ID: 0568ccf46a6b
|
||||
Revises: e209dc5a8156
|
||||
Create Date: 2024-06-19 14:25:36.376046
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0568ccf46a6b"
|
||||
down_revision = "e209dc5a8156"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("current_alternate_model", sa.String(), nullable=True),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("chat_session", "current_alternate_model")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add search doc relevance details
|
||||
|
||||
Revision ID: 05c07bf07c00
|
||||
Revises: b896bbd0d5a7
|
||||
Create Date: 2024-07-10 17:48:15.886653
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "05c07bf07c00"
|
||||
down_revision = "b896bbd0d5a7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("is_relevant", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column("relevance_explanation", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_doc", "relevance_explanation")
|
||||
op.drop_column("search_doc", "is_relevant")
|
||||
@@ -0,0 +1,86 @@
|
||||
"""remove-feedback-foreignkey-constraint
|
||||
|
||||
Revision ID: 23957775e5f5
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-27 16:04:51.480437
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"document_retrieval_feedback",
|
||||
"chat_message_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
"""add alternate assistant to chat message
|
||||
|
||||
Revision ID: 3a7802814195
|
||||
Revises: 23957775e5f5
|
||||
Create Date: 2024-06-05 11:18:49.966333
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a7802814195"
|
||||
down_revision = "23957775e5f5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "alternate_assistant_id")
|
||||
@@ -0,0 +1,65 @@
|
||||
"""add cloud embedding model and update embedding_model
|
||||
|
||||
Revision ID: 44f856ae2a4a
|
||||
Revises: d716b0791ddd
|
||||
Create Date: 2024-06-28 20:01:05.927647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "44f856ae2a4a"
|
||||
down_revision = "d716b0791ddd"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create embedding_provider table
|
||||
op.create_table(
|
||||
"embedding_provider",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("default_model_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add cloud_provider_id to embedding_model table
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove foreign key constraints
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Remove cloud_provider_id column
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Drop embedding_provider table
|
||||
op.drop_table("embedding_provider")
|
||||
@@ -0,0 +1,23 @@
|
||||
"""added is_internet to DBDoc
|
||||
|
||||
Revision ID: 4505fd7302e1
|
||||
Revises: c18cdf4b497e
|
||||
Create Date: 2024-06-18 20:46:09.095034
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4505fd7302e1"
|
||||
down_revision = "c18cdf4b497e"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
|
||||
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "display_name")
|
||||
op.drop_column("search_doc", "is_internet")
|
||||
@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "48d14957fe80"
|
||||
down_revision = "b85f02ec1308"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""added slack_auto_filter
|
||||
|
||||
Revision ID: 7aea705850d5
|
||||
Revises: 4505fd7302e1
|
||||
Create Date: 2024-07-10 11:01:23.581015
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7aea705850d5"
|
||||
down_revision = "4505fd7302e1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"slack_bot_config",
|
||||
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"slack_bot_config",
|
||||
"enable_auto_filters",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("slack_bot_config", "enable_auto_filters")
|
||||
@@ -0,0 +1,23 @@
|
||||
"""backfill is_internet data to False
|
||||
|
||||
Revision ID: b896bbd0d5a7
|
||||
Revises: 44f856ae2a4a
|
||||
Create Date: 2024-07-16 15:21:05.718571
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b896bbd0d5a7"
|
||||
down_revision = "44f856ae2a4a"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,51 @@
|
||||
"""create usage reports table
|
||||
|
||||
Revision ID: bc9771dccadf
|
||||
Revises: 0568ccf46a6b
|
||||
Create Date: 2024-06-18 10:04:26.800282
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bc9771dccadf"
|
||||
down_revision = "0568ccf46a6b"
|
||||
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"usage_reports",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("report_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"requestor_user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("period_from", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("period_to", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["report_name"],
|
||||
["file_store.file_name"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["requestor_user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("usage_reports")
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Add standard_answer tables
|
||||
|
||||
Revision ID: c18cdf4b497e
|
||||
Revises: 3a7802814195
|
||||
Create Date: 2024-06-06 15:15:02.000648
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c18cdf4b497e"
|
||||
down_revision = "3a7802814195"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"standard_answer",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("keyword", sa.String(), nullable=False),
|
||||
sa.Column("answer", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("keyword"),
|
||||
)
|
||||
op.create_table(
|
||||
"standard_answer_category",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_table(
|
||||
"standard_answer__standard_answer_category",
|
||||
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_category_id"],
|
||||
["standard_answer_category.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_id"],
|
||||
["standard_answer.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["slack_bot_config_id"],
|
||||
["slack_bot_config.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_category_id"],
|
||||
["standard_answer_category.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_session", "slack_thread_id")
|
||||
|
||||
op.drop_table("slack_bot_config__standard_answer_category")
|
||||
op.drop_table("standard_answer__standard_answer_category")
|
||||
op.drop_table("standard_answer_category")
|
||||
op.drop_table("standard_answer")
|
||||
@@ -0,0 +1,45 @@
|
||||
"""combined slack id fields
|
||||
|
||||
Revision ID: d716b0791ddd
|
||||
Revises: 7aea705850d5
|
||||
Create Date: 2024-07-10 17:57:45.630550
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d716b0791ddd"
|
||||
down_revision = "7aea705850d5"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_bot_config
|
||||
SET channel_config = jsonb_set(
|
||||
channel_config,
|
||||
'{respond_member_group_list}',
|
||||
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
|
||||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
|
||||
) - 'respond_team_member_list' - 'respond_slack_group_list'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_bot_config
|
||||
SET channel_config = jsonb_set(
|
||||
jsonb_set(
|
||||
channel_config - 'respond_member_group_list',
|
||||
'{respond_team_member_list}',
|
||||
'[]'::jsonb
|
||||
),
|
||||
'{respond_slack_group_list}',
|
||||
'[]'::jsonb
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
"""added-prune-frequency
|
||||
|
||||
Revision ID: e209dc5a8156
|
||||
Revises: 48d14957fe80
|
||||
Create Date: 2024-06-16 16:02:35.273231
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector", "prune_freq")
|
||||
2
backend/assets/.gitignore
vendored
Normal file
2
backend/assets/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
||||
@@ -1,6 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
@@ -19,7 +20,7 @@ def _get_access_for_documents(
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, is_public)
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
@@ -38,12 +39,6 @@ def get_access_for_documents(
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
10
backend/danswer/access/utils.py
Normal file
10
backend/danswer/access/utils.py
Normal file
@@ -0,0 +1,10 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
21
backend/danswer/auth/invited_users.py
Normal file
21
backend/danswer/auth/invited_users.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
|
||||
USER_STORE_KEY = "INVITED_USERS"
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_dynamic_config_store()
|
||||
return cast(list, store.load(USER_STORE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_dynamic_config_store()
|
||||
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
@@ -9,6 +9,12 @@ class UserRole(str, Enum):
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
role: UserRole
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -27,6 +26,7 @@ from fastapi_users.openapi import OpenAPIResponseType
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
@@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_session
|
||||
@@ -54,14 +55,13 @@ from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
USER_WHITELIST_FILE = "/home/danswer_whitelist.txt"
|
||||
_user_whitelist: list[str] | None = None
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
@@ -92,20 +92,8 @@ def user_needs_to_be_verified() -> bool:
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def get_user_whitelist() -> list[str]:
|
||||
global _user_whitelist
|
||||
if _user_whitelist is None:
|
||||
if os.path.exists(USER_WHITELIST_FILE):
|
||||
with open(USER_WHITELIST_FILE, "r") as file:
|
||||
_user_whitelist = [line.strip() for line in file]
|
||||
else:
|
||||
_user_whitelist = []
|
||||
|
||||
return _user_whitelist
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
whitelist = get_user_whitelist()
|
||||
whitelist = get_invited_users()
|
||||
if (whitelist and email not in whitelist) or not email:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
@@ -163,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
@@ -4,13 +4,22 @@ from typing import cast
|
||||
from celery import Celery # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_sync_doc_set
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
@@ -22,8 +31,6 @@ from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SYNC_DB_API
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
@@ -90,6 +97,74 @@ def cleanup_connector_credential_pair_task(
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
|
||||
return
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
cc_pair.connector.source,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
db_session,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
runnable_connector
|
||||
)
|
||||
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
if len(doc_ids_to_remove) == 0:
|
||||
logger.info(
|
||||
f"No docs to prune from {cc_pair.connector.source} connector"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
|
||||
)
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=doc_ids_to_remove,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id} due to {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_document_set_sync_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
@@ -177,32 +252,48 @@ def sync_document_set_task(document_set_id: int) -> None:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
"""Runs periodically to check if any document sets are out of sync
|
||||
Creates a task to sync the set if needed"""
|
||||
"""Runs periodically to check if any sync tasks should be run and adds them
|
||||
to the queue"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(
|
||||
latest_sync, db_session
|
||||
):
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
if should_sync_doc_set(document_set, db_session):
|
||||
logger.info(f"Syncing the {document_set.name} document set")
|
||||
sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_prune_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_prune_task() -> None:
|
||||
"""Runs periodically to check if any prune tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
|
||||
for cc_pair in all_cc_pairs:
|
||||
if should_prune_cc_pair(
|
||||
connector=cc_pair.connector,
|
||||
credential=cc_pair.credential,
|
||||
db_session=db_session,
|
||||
):
|
||||
logger.info(f"Pruning the {cc_pair.connector.name} connector")
|
||||
|
||||
prune_documents_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
@@ -212,3 +303,11 @@ celery_app.conf.beat_schedule = {
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-prune": {
|
||||
"task": "check_for_prune_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
}
|
||||
)
|
||||
9
backend/danswer/background/celery/celery_run.py
Normal file
9
backend/danswer/background/celery/celery_run.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Entry point for running celery worker / celery beat."""
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
celery_app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.celery_app", "celery_app"
|
||||
)
|
||||
@@ -1,8 +1,32 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
@@ -21,3 +45,94 @@ def get_deletion_status(
|
||||
credential_id=credential_id,
|
||||
status=task_state.status,
|
||||
)
|
||||
|
||||
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
if not connector.prune_freq:
|
||||
return False
|
||||
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=connector.id, credential_id=credential.id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
|
||||
if not last_pruning_task:
|
||||
time_since_initialization = current_db_time - connector.time_created
|
||||
if time_since_initialization.total_seconds() >= connector.prune_freq:
|
||||
return True
|
||||
return False
|
||||
|
||||
if PREVENT_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
logger.info("Another Connector is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
return False
|
||||
|
||||
time_since_last_pruning = current_db_time - last_pruning_task.start_time
|
||||
return time_since_last_pruning.total_seconds() >= connector.prune_freq
|
||||
|
||||
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
doc_batch_generator = None
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
@@ -41,7 +41,7 @@ logger = setup_logger()
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def _delete_connector_credential_pair_batch(
|
||||
def delete_connector_credential_pair_batch(
|
||||
document_ids: list[str],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
@@ -169,7 +169,7 @@ def delete_connector_credential_pair(
|
||||
if not documents:
|
||||
break
|
||||
|
||||
_delete_connector_credential_pair_batch(
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
|
||||
@@ -105,7 +105,9 @@ class SimpleJobClient:
|
||||
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug("No available workers to run job")
|
||||
logger.debug(
|
||||
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
)
|
||||
return None
|
||||
|
||||
job_id = self.job_id_counter
|
||||
|
||||
@@ -6,11 +6,7 @@ from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import (
|
||||
_delete_connector_credential_pair_batch,
|
||||
)
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.configs.app_configs import DISABLE_DOCUMENT_CLEANUP
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -21,8 +17,6 @@ from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
@@ -37,6 +31,7 @@ from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -46,7 +41,7 @@ def _get_document_generator(
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> tuple[GenerateDocumentsOutput, bool]:
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
|
||||
@@ -57,16 +52,13 @@ def _get_document_generator(
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
attempt.credential,
|
||||
db_session,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
@@ -75,7 +67,7 @@ def _get_document_generator(
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
is_listing_complete = True
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
@@ -88,13 +80,12 @@ def _get_document_generator(
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
is_listing_complete = False
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator, is_listing_complete
|
||||
return doc_batch_generator
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
@@ -107,7 +98,6 @@ def _run_indexing(
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
@@ -125,6 +115,8 @@ def _run_indexing(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -166,7 +158,7 @@ def _run_indexing(
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
doc_batch_generator, is_listing_complete = _get_document_generator(
|
||||
doc_batch_generator = _get_document_generator(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
@@ -224,39 +216,6 @@ def _run_indexing(
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
if is_listing_complete and not DISABLE_DOCUMENT_CLEANUP:
|
||||
# clean up all documents from the index that have not been returned from the connector
|
||||
all_indexed_document_ids = {
|
||||
d.id
|
||||
for d in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
)
|
||||
}
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids
|
||||
)
|
||||
logger.debug(
|
||||
f"Cleaning up {len(doc_ids_to_remove)} documents that are not contained in the newest connector state"
|
||||
)
|
||||
|
||||
# delete docs from cc-pair and receive the number of completely deleted docs in return
|
||||
_delete_connector_credential_pair_batch(
|
||||
document_ids=doc_ids_to_remove,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=len(doc_ids_to_remove),
|
||||
)
|
||||
|
||||
run_end_dt = window_end
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
@@ -329,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
@@ -346,11 +306,14 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int) -> None:
|
||||
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
@@ -22,6 +22,15 @@ def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -307,10 +309,18 @@ def kickoff_indexing_jobs(
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint, attempt.id, pure=False
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
else:
|
||||
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if run:
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
@@ -333,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -398,6 +410,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -9,42 +8,30 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inf_chunk.document_id,
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
# This one is using the combined content of all the chunks of the section
|
||||
# In default settings, this is the same as just the content of base chunk
|
||||
content=inf_chunk.combined_content,
|
||||
blurb=inf_chunk.blurb,
|
||||
semantic_identifier=inf_chunk.semantic_identifier,
|
||||
source_type=inf_chunk.source_type,
|
||||
metadata=inf_chunk.metadata,
|
||||
updated_at=inf_chunk.updated_at,
|
||||
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
|
||||
source_links=inf_chunk.source_links,
|
||||
content=inference_section.combined_content,
|
||||
blurb=inference_section.center_chunk.blurb,
|
||||
semantic_identifier=inference_section.center_chunk.semantic_identifier,
|
||||
source_type=inference_section.center_chunk.source_type,
|
||||
metadata=inference_section.center_chunk.metadata,
|
||||
updated_at=inference_section.center_chunk.updated_at,
|
||||
link=inference_section.center_chunk.source_links[0]
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
)
|
||||
|
||||
|
||||
def map_document_id_order(
|
||||
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
|
||||
) -> dict[str, int]:
|
||||
order_mapping = {}
|
||||
current = 1 if one_indexed else 0
|
||||
for chunk in chunks:
|
||||
if chunk.document_id not in order_mapping:
|
||||
order_mapping[chunk.document_id] = current
|
||||
current += 1
|
||||
|
||||
return order_mapping
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -50,7 +48,7 @@ def load_personas_from_yaml(
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] | None = [
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
@@ -58,22 +56,24 @@ def load_personas_from_yaml(
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
if not doc_sets:
|
||||
doc_sets = None
|
||||
|
||||
prompt_set_names = persona["prompts"]
|
||||
if not prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] | None = None
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
prompts = [
|
||||
doc_set_ids = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if not prompts:
|
||||
prompts = None
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
upsert_persona(
|
||||
@@ -91,8 +91,8 @@ def load_personas_from_yaml(
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompts=cast(list[PromptDBModel] | None, prompts),
|
||||
document_sets=doc_sets,
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
|
||||
return initial_dict
|
||||
|
||||
|
||||
# Second chunk of info for streaming QA
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class RelevanceChunk(BaseModel):
|
||||
# TODO make this document level. Also slight misnomer here as this is actually
|
||||
# done at the section level currently rather than the chunk
|
||||
relevant: bool | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class LLMRelevanceSummaryResponse(BaseModel):
|
||||
relevance_summaries: dict[str, RelevanceChunk]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
@@ -10,14 +10,15 @@ from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -34,6 +35,7 @@ from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
@@ -46,10 +48,15 @@ from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llm_for_persona
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
@@ -64,6 +71,14 @@ from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import (
|
||||
internet_search_response_to_search_docs,
|
||||
)
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@@ -141,6 +156,37 @@ def _handle_search_tool_response_summary(
|
||||
)
|
||||
|
||||
|
||||
def _handle_internet_search_tool_response_summary(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
||||
internet_search_response = cast(InternetSearchResponse, packet.response)
|
||||
server_search_docs = internet_search_response_to_search_docs(
|
||||
internet_search_response
|
||||
)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in server_search_docs
|
||||
]
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.HYBRID,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
)
|
||||
|
||||
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
@@ -168,7 +214,7 @@ def _check_should_force_search(
|
||||
args = {"query": new_msg_req.message}
|
||||
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool.NAME,
|
||||
tool_name=SearchTool._NAME,
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
@@ -223,7 +269,15 @@ def stream_chat_message_objects(
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
persona = chat_session.persona
|
||||
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id, user=user, db_session=db_session
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
@@ -235,7 +289,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
try:
|
||||
llm = get_llm_for_persona(
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
@@ -328,7 +382,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
selected_db_search_docs = None
|
||||
selected_llm_docs: list[LlmDoc] | None = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
if reference_doc_ids:
|
||||
identifier_tuples = get_doc_query_identifiers_from_model(
|
||||
search_doc_ids=reference_doc_ids,
|
||||
@@ -338,8 +392,8 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
# May extend to include chunk ranges
|
||||
selected_llm_docs = inference_documents_from_ids(
|
||||
# May extend to use sections instead in the future
|
||||
selected_sections = inference_sections_from_ids(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=document_index,
|
||||
)
|
||||
@@ -380,6 +434,7 @@ def stream_chat_message_objects(
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
# error=,
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
@@ -389,11 +444,15 @@ def stream_chat_message_objects(
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
@@ -411,21 +470,22 @@ def stream_chat_message_objects(
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_docs=selected_llm_docs,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
dalle_key = llm.config.api_key
|
||||
img_generation_llm_config = llm.config
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
@@ -442,9 +502,30 @@ def stream_chat_message_objects(
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(api_key=dalle_key)
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(api_key=bing_api_key)
|
||||
]
|
||||
|
||||
continue
|
||||
@@ -481,10 +562,14 @@ def stream_chat_message_objects(
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
or get_llm_for_persona(
|
||||
persona=persona,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
or get_main_llm_from_tuple(
|
||||
get_llms_for_persona(
|
||||
persona=persona,
|
||||
llm_override=(
|
||||
new_msg_req.llm_override or chat_session.llm_override
|
||||
),
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
)
|
||||
),
|
||||
message_history=[
|
||||
@@ -548,6 +633,15 @@ def stream_chat_message_objects(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
@@ -589,7 +683,7 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name()] = tool_id
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
|
||||
@@ -4,6 +4,7 @@ import urllib.parse
|
||||
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
@@ -45,13 +46,14 @@ DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET")
|
||||
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
|
||||
|
||||
# Turn off mask if admin users should see full credentials for data connectors.
|
||||
MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
@@ -160,6 +162,11 @@ WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_S
|
||||
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
|
||||
WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS")
|
||||
|
||||
HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
|
||||
"HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
|
||||
HtmlBasedConnectorTransformLinksStrategy.STRIP,
|
||||
)
|
||||
|
||||
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
== "true"
|
||||
@@ -178,6 +185,12 @@ CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Save pages labels as Danswer metadata tags
|
||||
# The reason to skip this would be to reduce the number of calls to Confluence due to rate limit concerns
|
||||
CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -199,6 +212,22 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
|
||||
|
||||
PREVENT_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
|
||||
|
||||
# comma delimited list of zendesk article labels to skip indexing for
|
||||
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
|
||||
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
|
||||
).split(",")
|
||||
|
||||
|
||||
#####
|
||||
# Indexing Configs
|
||||
@@ -219,19 +248,17 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
CHUNK_OVERLAP = 0
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
# If set to true, then will not clean up documents that "no longer exist" when running Load connectors
|
||||
DISABLE_DOCUMENT_CLEANUP = (
|
||||
os.environ.get("DISABLE_DOCUMENT_CLEANUP", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
@@ -246,10 +273,14 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
|
||||
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
|
||||
)
|
||||
# Logs every model prompt and output, mostly used for development or exploration purposes
|
||||
# Sets LiteLLM to verbose logging
|
||||
LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# Logs Danswer only model interactions like prompts, responses, messages etc.
|
||||
LOG_DANSWER_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -268,3 +299,15 @@ TOKEN_BUDGET_GLOBALLY_ENABLED = (
|
||||
CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
# NOTE: this should only be enabled if you have purchased an enterprise license.
|
||||
# if you're interested in an enterprise license, please reach out to us at
|
||||
# founders@danswer.ai OR message Chris Weaver or Yuhong Sun in the Danswer
|
||||
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,10 @@ PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
# It cannot be too large due to cost and latency implications
|
||||
NUM_RERANKED_RESULTS = 20
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
@@ -25,9 +28,10 @@ BASE_RECENCY_DECAY = 0.5
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
DISABLE_LLM_FILTER_EXTRACTION = (
|
||||
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
|
||||
)
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
# Keyword Search Drop Stopwords
|
||||
# If user has changed the default model, would most likely be to use a multilingual
|
||||
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
|
||||
@@ -64,9 +66,31 @@ TITLE_CONTENT_RATIO = max(
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
LANGUAGE_HINT = "\n" + (
|
||||
os.environ.get("LANGUAGE_HINT")
|
||||
or "IMPORTANT: Respond in the same language as my query!"
|
||||
)
|
||||
LANGUAGE_CHAT_NAMING_HINT = (
|
||||
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
DISABLE_AGENTIC_SEARCH = (
|
||||
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
|
||||
).lower() == "true"
|
||||
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
|
||||
# Internet Search
|
||||
BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
|
||||
@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
|
||||
TIME_FILTER = "time_filter"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
MATCH_HIGHLIGHTS = "match_highlights"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -41,17 +42,16 @@ DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
TOKEN_BUDGET = "token_budget"
|
||||
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
|
||||
ENABLE_TOKEN_BUDGET = "enable_token_budget"
|
||||
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
|
||||
|
||||
# For chunking/processing chunks
|
||||
TITLE_SEPARATOR = "\n\r\n"
|
||||
MAX_CHUNK_TITLE_LEN = 1000
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
# For combining attributes, doesn't have to be unique/perfect to work
|
||||
INDEX_SEPARATOR = "==="
|
||||
|
||||
# For File Connector Metadata override file
|
||||
DANSWER_METADATA_FILENAME = ".danswer_metadata.json"
|
||||
|
||||
# Messages
|
||||
DISABLED_GEN_AI_MSG = (
|
||||
@@ -102,6 +102,21 @@ class DocumentSource(str, Enum):
|
||||
CLICKUP = "clickup"
|
||||
MEDIAWIKI = "mediawiki"
|
||||
WIKIPEDIA = "wikipedia"
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
S3 = "s3"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
|
||||
# Special case, for internet search
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class DocumentIndexType(str, Enum):
|
||||
@@ -117,6 +132,11 @@ class AuthType(str, Enum):
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
ENDORSE = "endorse" # boost this document for all future queries
|
||||
REJECT = "reject" # down-boost this document for all future queries
|
||||
@@ -142,4 +162,5 @@ class FileOrigin(str, Enum):
|
||||
CHAT_UPLOAD = "chat_upload"
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
|
||||
@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Auto detect query options like time cutoff or heavily favor recently updated docs
|
||||
DISABLE_DANSWER_BOT_FILTER_DETECT = (
|
||||
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
|
||||
@@ -39,8 +39,8 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
|
||||
# Purely an optimization, memory limitation consideration
|
||||
BATCH_SIZE_ENCODE_CHUNKS = 8
|
||||
# For score display purposes, only way is to know the expected ranges
|
||||
CROSS_ENCODER_RANGE_MAX = 12
|
||||
CROSS_ENCODER_RANGE_MIN = -12
|
||||
CROSS_ENCODER_RANGE_MAX = 1
|
||||
CROSS_ENCODER_RANGE_MIN = 0
|
||||
|
||||
# Unused currently, can't be used with the current default encoder model due to its output range
|
||||
SEARCH_DISTANCE_CUTOFF = 0
|
||||
|
||||
0
backend/danswer/connectors/blob/__init__.py
Normal file
0
backend/danswer/connectors/blob/__init__.py
Normal file
277
backend/danswer/connectors/blob/connector.py
Normal file
277
backend/danswer/connectors/blob/connector.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
from mypy_boto3_s3 import S3Client
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import BlobType
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
bucket_type: str,
|
||||
bucket_name: str,
|
||||
prefix: str = "",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.bucket_type: BlobType = BlobType(bucket_type)
|
||||
self.bucket_name = bucket_name
|
||||
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
|
||||
self.batch_size = batch_size
|
||||
self.s3_client: Optional[S3Client] = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Checks for boto3 credentials based on the bucket type.
|
||||
(1) R2: Access Key ID, Secret Access Key, Account ID
|
||||
(2) S3: AWS Access Key ID, AWS Secret Access Key
|
||||
(3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
|
||||
(4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key
|
||||
|
||||
For each bucket type, the method initializes the appropriate S3 client:
|
||||
- R2: Uses Cloudflare R2 endpoint with S3v4 signature
|
||||
- S3: Creates a standard boto3 S3 client
|
||||
- GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
|
||||
- OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint
|
||||
|
||||
Raises ConnectorMissingCredentialError if required credentials are missing.
|
||||
Raises ValueError for unsupported bucket types.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
|
||||
)
|
||||
|
||||
if self.bucket_type == BlobType.R2:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Cloudflare R2")
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
|
||||
aws_access_key_id=credentials["r2_access_key_id"],
|
||||
aws_secret_access_key=credentials["r2_secret_access_key"],
|
||||
region_name="auto",
|
||||
config=Config(signature_version="s3v4"),
|
||||
)
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
)
|
||||
self.s3_client = session.client("s3")
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url="https://storage.googleapis.com",
|
||||
aws_access_key_id=credentials["access_key_id"],
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name="auto",
|
||||
)
|
||||
|
||||
elif self.bucket_type == BlobType.OCI_STORAGE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
|
||||
aws_access_key_id=credentials["access_key_id"],
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name=credentials["region"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
return None
|
||||
|
||||
def _download_object(self, key: str) -> bytes:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
|
||||
return object["Body"].read()
|
||||
|
||||
# NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
|
||||
# def _get_presigned_url(self, key: str) -> str:
|
||||
# if self.s3_client is None:
|
||||
# raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
# url = self.s3_client.generate_presigned_url(
|
||||
# "get_object",
|
||||
# Params={"Bucket": self.bucket_name, "Key": key},
|
||||
# ExpiresIn=self.presign_length,
|
||||
# )
|
||||
# return url
|
||||
|
||||
def _get_blob_link(self, key: str) -> str:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
|
||||
if self.bucket_type == BlobType.R2:
|
||||
account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
|
||||
return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
region = self.s3_client.meta.region_name
|
||||
return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"
|
||||
|
||||
elif self.bucket_type == BlobType.OCI_STORAGE:
|
||||
namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
|
||||
region = self.s3_client.meta.region_name
|
||||
return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
def _yield_blob_objects(
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
|
||||
if not start <= last_modified <= end:
|
||||
continue
|
||||
|
||||
downloaded_file = self._download_object(obj["Key"])
|
||||
link = self._get_blob_link(obj["Key"])
|
||||
name = os.path.basename(obj["Key"])
|
||||
|
||||
try:
|
||||
text = extract_file_text(
|
||||
name,
|
||||
BytesIO(downloaded_file),
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
|
||||
sections=[Section(link=link, text=text)],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error decoding object {obj['Key']} as UTF-8: {e}"
|
||||
)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.info("Loading blob objects")
|
||||
return self._yield_blob_objects(
|
||||
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
end=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blog storage")
|
||||
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
for batch in self._yield_blob_objects(start_datetime, end_datetime):
|
||||
yield batch
|
||||
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
credentials_dict = {
|
||||
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
}
|
||||
|
||||
# Initialize the connector
|
||||
connector = BlobStorageConnector(
|
||||
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
|
||||
bucket_name=os.environ.get("BUCKET_NAME") or "test",
|
||||
prefix="",
|
||||
)
|
||||
|
||||
try:
|
||||
connector.load_credentials(credentials_dict)
|
||||
document_batch_generator = connector.load_from_state()
|
||||
for document_batch in document_batch_generator:
|
||||
print("First batch of documents:")
|
||||
for doc in document_batch:
|
||||
print(f"Document ID: {doc.id}")
|
||||
print(f"Semantic Identifier: {doc.semantic_identifier}")
|
||||
print(f"Source: {doc.source}")
|
||||
print(f"Updated At: {doc.doc_updated_at}")
|
||||
print("Sections:")
|
||||
for section in doc.sections:
|
||||
print(f" - Link: {section.link}")
|
||||
print(f" - Text: {section.text[:100]}...")
|
||||
print("---")
|
||||
break
|
||||
|
||||
except ConnectorMissingCredentialError as e:
|
||||
print(f"Error: {e}")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
@@ -15,6 +15,7 @@ from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -36,16 +37,18 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space
|
||||
# 2. Include attachments, etc
|
||||
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
# 1. Include attachments, etc
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
|
||||
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
|
||||
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
|
||||
|
||||
wiki_base is https://danswer.atlassian.net/wiki
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
@@ -54,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
|
||||
+ parsed_url.netloc
|
||||
+ parsed_url.path.split("/spaces")[0]
|
||||
)
|
||||
space = parsed_url.path.split("/")[3]
|
||||
return wiki_base, space
|
||||
|
||||
path_parts = parsed_url.path.split("/")
|
||||
space = path_parts[3]
|
||||
|
||||
page_id = path_parts[5] if len(path_parts) > 5 else ""
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
|
||||
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
|
||||
"""Sample
|
||||
https://danswer.ai/confluence/display/1234abcd/overview
|
||||
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
|
||||
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
|
||||
wiki_base is https://danswer.ai/confluence
|
||||
space is 1234abcd
|
||||
page_id is 5678efgh
|
||||
"""
|
||||
# /display/ is always right before the space and at the end of the base url
|
||||
# /display/ is always right before the space and at the end of the base print()
|
||||
DISPLAY = "/display/"
|
||||
PAGE = "/pages/"
|
||||
|
||||
parsed_url = urlparse(wiki_url)
|
||||
wiki_base = (
|
||||
@@ -75,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st
|
||||
+ parsed_url.path.split(DISPLAY)[0]
|
||||
)
|
||||
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
|
||||
return wiki_base, space
|
||||
page_id = ""
|
||||
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
|
||||
page_id = content[1]
|
||||
return wiki_base, space, page_id
|
||||
|
||||
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
|
||||
is_confluence_cloud = (
|
||||
".atlassian.net/wiki/spaces/" in wiki_url
|
||||
or ".jira.com/wiki/spaces/" in wiki_url
|
||||
@@ -86,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
|
||||
try:
|
||||
if is_confluence_cloud:
|
||||
wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
|
||||
wiki_url
|
||||
)
|
||||
else:
|
||||
wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
|
||||
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
|
||||
wiki_url
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
|
||||
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return wiki_base, space, is_confluence_cloud
|
||||
return wiki_base, space, page_id, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@@ -195,10 +212,135 @@ def _comment_dfs(
|
||||
return comments_str
|
||||
|
||||
|
||||
class RecursiveIndexer:
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
confluence_client: Confluence,
|
||||
index_origin: bool,
|
||||
origin_page_id: str,
|
||||
) -> None:
|
||||
self.batch_size = 1
|
||||
# batch_size
|
||||
self.confluence_client = confluence_client
|
||||
self.index_origin = index_origin
|
||||
self.origin_page_id = origin_page_id
|
||||
self.pages = self.recurse_children_pages(0, self.origin_page_id)
|
||||
|
||||
def get_pages(self, ind: int, size: int) -> list[dict]:
|
||||
if ind * size > len(self.pages):
|
||||
return []
|
||||
return self.pages[ind * size : (ind + 1) * size]
|
||||
|
||||
def _fetch_origin_page(
|
||||
self,
|
||||
) -> dict[str, Any]:
|
||||
get_page_by_id = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_by_id
|
||||
)
|
||||
try:
|
||||
origin_page = get_page_by_id(
|
||||
self.origin_page_id, expand="body.storage.value,version"
|
||||
)
|
||||
return origin_page
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
|
||||
)
|
||||
return {}
|
||||
|
||||
def recurse_children_pages(
|
||||
self,
|
||||
start_ind: int,
|
||||
page_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
pages: list[dict[str, Any]] = []
|
||||
current_level_pages: list[dict[str, Any]] = []
|
||||
next_level_pages: list[dict[str, Any]] = []
|
||||
|
||||
# Initial fetch of first level children
|
||||
index = start_ind
|
||||
while batch := self._fetch_single_depth_child_pages(
|
||||
index, self.batch_size, page_id
|
||||
):
|
||||
current_level_pages.extend(batch)
|
||||
index += len(batch)
|
||||
|
||||
pages.extend(current_level_pages)
|
||||
|
||||
# Recursively index children and children's children, etc.
|
||||
while current_level_pages:
|
||||
for child in current_level_pages:
|
||||
child_index = 0
|
||||
while child_batch := self._fetch_single_depth_child_pages(
|
||||
child_index, self.batch_size, child["id"]
|
||||
):
|
||||
next_level_pages.extend(child_batch)
|
||||
child_index += len(child_batch)
|
||||
|
||||
pages.extend(next_level_pages)
|
||||
current_level_pages = next_level_pages
|
||||
next_level_pages = []
|
||||
|
||||
if self.index_origin:
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
|
||||
return pages
|
||||
|
||||
def _fetch_single_depth_child_pages(
|
||||
self, start_ind: int, batch_size: int, page_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
child_pages: list[dict[str, Any]] = []
|
||||
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
self.confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
|
||||
child_pages.extend(child_page)
|
||||
return child_pages
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Batch failed with page {page_id} at offset {start_ind} "
|
||||
f"with size {batch_size}, processing pages individually..."
|
||||
)
|
||||
|
||||
for i in range(batch_size):
|
||||
ind = start_ind + i
|
||||
try:
|
||||
child_page = get_page_child_by_type(
|
||||
page_id,
|
||||
type="page",
|
||||
start=ind,
|
||||
limit=1,
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
child_pages.extend(child_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
|
||||
raise e
|
||||
|
||||
return child_pages
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_page_url: str,
|
||||
index_origin: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
@@ -209,11 +351,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
|
||||
wiki_page_url
|
||||
)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_origin = index_origin
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
self.page_id,
|
||||
self.is_cloud,
|
||||
) = extract_confluence_keys_from_url(wiki_page_url)
|
||||
|
||||
self.space_level_scan = False
|
||||
|
||||
self.confluence_client: Confluence | None = None
|
||||
|
||||
if self.page_id is None or self.page_id == "":
|
||||
self.space_level_scan = True
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
username = credentials["confluence_username"]
|
||||
access_token = credentials["confluence_access_token"]
|
||||
@@ -231,8 +389,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self,
|
||||
confluence_client: Confluence,
|
||||
start_ind: int,
|
||||
) -> Collection[dict[str, Any]]:
|
||||
def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_all_pages_from_space
|
||||
)
|
||||
@@ -241,9 +399,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.space,
|
||||
start=start_ind,
|
||||
limit=batch_size,
|
||||
status="current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
except Exception:
|
||||
@@ -262,9 +422,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.space,
|
||||
start=start_ind + i,
|
||||
limit=1,
|
||||
status="current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None,
|
||||
status=(
|
||||
"current"
|
||||
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
else None
|
||||
),
|
||||
expand="body.storage.value,version",
|
||||
)
|
||||
)
|
||||
@@ -285,17 +447,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
return view_pages
|
||||
|
||||
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
|
||||
if self.recursive_indexer is None:
|
||||
self.recursive_indexer = RecursiveIndexer(
|
||||
origin_page_id=self.page_id,
|
||||
batch_size=self.batch_size,
|
||||
confluence_client=self.confluence_client,
|
||||
index_origin=self.index_origin,
|
||||
)
|
||||
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
|
||||
pages: list[dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
return _fetch(start_ind, self.batch_size)
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
)
|
||||
return pages
|
||||
|
||||
except Exception as e:
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
# error checking phase, only reachable if `self.continue_on_failure=True`
|
||||
pages: list[dict[str, Any]] = []
|
||||
for i in range(self.batch_size):
|
||||
try:
|
||||
pages.extend(_fetch(start_ind + i, 1))
|
||||
pages = (
|
||||
_fetch_space(start_ind, self.batch_size)
|
||||
if self.space_level_scan
|
||||
else _fetch_page(start_ind, self.batch_size)
|
||||
)
|
||||
return pages
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ran into exception when fetching pages from Confluence"
|
||||
@@ -307,6 +493,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
get_page_child_by_type = make_confluence_call_handle_rate_limit(
|
||||
confluence_client.get_page_child_by_type
|
||||
)
|
||||
|
||||
try:
|
||||
comment_pages = cast(
|
||||
Collection[dict[str, Any]],
|
||||
@@ -355,7 +542,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
page_id, start=0, limit=500
|
||||
)
|
||||
for attachment in attachments_container["results"]:
|
||||
if attachment["metadata"]["mediaType"] in ["image/jpeg", "image/png"]:
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
continue
|
||||
|
||||
if attachment["title"] not in files_in_used:
|
||||
@@ -366,7 +560,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
if response.status_code == 200:
|
||||
extract = extract_file_text(
|
||||
attachment["title"], io.BytesIO(response.content)
|
||||
attachment["title"], io.BytesIO(response.content), False
|
||||
)
|
||||
files_attachment_content.append(extract)
|
||||
|
||||
@@ -386,8 +580,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
batch = self._fetch_pages(self.confluence_client, start_ind)
|
||||
|
||||
for page in batch:
|
||||
last_modified_str = page["version"]["when"]
|
||||
author = cast(str | None, page["version"].get("by", {}).get("email"))
|
||||
@@ -403,15 +597,18 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if time_filter is None or time_filter(last_modified):
|
||||
page_id = page["id"]
|
||||
|
||||
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
|
||||
# check disallowed labels
|
||||
if self.labels_to_skip:
|
||||
page_labels = self._fetch_labels(self.confluence_client, page_id)
|
||||
label_intersection = self.labels_to_skip.intersection(page_labels)
|
||||
if label_intersection:
|
||||
logger.info(
|
||||
f"Page with ID '{page_id}' has a label which has been "
|
||||
f"designated as disallowed: {label_intersection}. Skipping."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
page_html = (
|
||||
@@ -432,6 +629,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
page_text += attachment_text
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": self.space
|
||||
}
|
||||
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
@@ -440,12 +642,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=[BasicExpertInfo(email=author)]
|
||||
if author
|
||||
else None,
|
||||
metadata={
|
||||
"Wiki Space Name": self.space,
|
||||
},
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author)] if author else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from requests import HTTPError
|
||||
from retry import retry
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
|
||||
|
||||
|
||||
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
|
||||
@retry(
|
||||
exceptions=ConfluenceRateLimitError,
|
||||
tries=10,
|
||||
delay=1,
|
||||
max_delay=600, # 10 minutes
|
||||
backoff=2,
|
||||
jitter=1,
|
||||
)
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
raise ConfluenceRateLimitError()
|
||||
raise
|
||||
starting_delay = 5
|
||||
backoff = 2
|
||||
max_delay = 600
|
||||
|
||||
for attempt in range(10):
|
||||
try:
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response.status_code == 429
|
||||
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
|
||||
):
|
||||
retry_after = None
|
||||
try:
|
||||
retry_after = int(e.response.headers.get("Retry-After"))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if retry_after:
|
||||
logger.warning(
|
||||
f"Rate limit hit. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limit hit. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(starting_delay * (backoff**attempt), max_delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# re-raise, let caller handle
|
||||
raise
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.text_processing import is_valid_email
|
||||
|
||||
@@ -57,3 +58,7 @@ def process_in_batches(
|
||||
) -> Iterator[list[U]]:
|
||||
for i in range(0, len(objects), batch_size):
|
||||
yield [process_function(obj) for obj in objects[i : i + batch_size]]
|
||||
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
@@ -11,6 +11,9 @@ from requests import Response
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -58,63 +61,36 @@ class DiscourseConnector(PollConnector):
|
||||
self.category_id_map: dict[int, str] = {}
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.permissions: DiscoursePerms | None = None
|
||||
self.active_categories: set | None = None
|
||||
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
|
||||
if not self.permissions:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
return discourse_request(endpoint, self.permissions, params)
|
||||
|
||||
def _get_categories_map(
|
||||
self,
|
||||
) -> None:
|
||||
assert self.permissions is not None
|
||||
categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
|
||||
response = discourse_request(
|
||||
response = self._make_request(
|
||||
endpoint=categories_endpoint,
|
||||
perms=self.permissions,
|
||||
params={"include_subcategories": True},
|
||||
)
|
||||
categories = response.json()["category_list"]["categories"]
|
||||
|
||||
self.category_id_map = {
|
||||
category["id"]: category["name"]
|
||||
for category in categories
|
||||
if not self.categories or category["name"].lower() in self.categories
|
||||
cat["id"]: cat["name"]
|
||||
for cat in categories
|
||||
if not self.categories or cat["name"].lower() in self.categories
|
||||
}
|
||||
|
||||
def _get_latest_topics(
|
||||
self, start: datetime | None, end: datetime | None
|
||||
) -> list[int]:
|
||||
assert self.permissions is not None
|
||||
topic_ids = []
|
||||
|
||||
valid_categories = set(self.category_id_map.keys())
|
||||
|
||||
latest_endpoint = urllib.parse.urljoin(self.base_url, "latest.json")
|
||||
response = discourse_request(endpoint=latest_endpoint, perms=self.permissions)
|
||||
topics = response.json()["topic_list"]["topics"]
|
||||
for topic in topics:
|
||||
last_time = topic.get("last_posted_at")
|
||||
if not last_time:
|
||||
continue
|
||||
last_time_dt = time_str_to_utc(last_time)
|
||||
|
||||
if start and start > last_time_dt:
|
||||
continue
|
||||
if end and end < last_time_dt:
|
||||
continue
|
||||
|
||||
if valid_categories and topic.get("category_id") not in valid_categories:
|
||||
continue
|
||||
|
||||
topic_ids.append(topic["id"])
|
||||
|
||||
return topic_ids
|
||||
self.active_categories = set(self.category_id_map)
|
||||
|
||||
def _get_doc_from_topic(self, topic_id: int) -> Document:
|
||||
assert self.permissions is not None
|
||||
topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
|
||||
response = discourse_request(
|
||||
endpoint=topic_endpoint,
|
||||
perms=self.permissions,
|
||||
)
|
||||
response = self._make_request(endpoint=topic_endpoint)
|
||||
topic = response.json()
|
||||
|
||||
topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")
|
||||
@@ -138,10 +114,16 @@ class DiscourseConnector(PollConnector):
|
||||
sections.append(
|
||||
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
|
||||
)
|
||||
category_name = self.category_id_map.get(topic["category_id"])
|
||||
|
||||
metadata: dict[str, str | list[str]] = (
|
||||
{
|
||||
"category": category_name,
|
||||
}
|
||||
if category_name
|
||||
else {}
|
||||
)
|
||||
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"category": self.category_id_map[topic["category_id"]],
|
||||
}
|
||||
if topic.get("tags"):
|
||||
metadata["tags"] = topic["tags"]
|
||||
|
||||
@@ -157,26 +139,78 @@ class DiscourseConnector(PollConnector):
|
||||
)
|
||||
return doc
|
||||
|
||||
def _get_latest_topics(
|
||||
self, start: datetime | None, end: datetime | None, page: int
|
||||
) -> list[int]:
|
||||
assert self.permissions is not None
|
||||
topic_ids = []
|
||||
|
||||
if not self.categories:
|
||||
latest_endpoint = urllib.parse.urljoin(
|
||||
self.base_url, f"latest.json?page={page}"
|
||||
)
|
||||
response = self._make_request(endpoint=latest_endpoint)
|
||||
topics = response.json()["topic_list"]["topics"]
|
||||
|
||||
else:
|
||||
topics = []
|
||||
empty_categories = []
|
||||
|
||||
for category_id in self.category_id_map.keys():
|
||||
category_endpoint = urllib.parse.urljoin(
|
||||
self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
|
||||
)
|
||||
response = self._make_request(endpoint=category_endpoint)
|
||||
new_topics = response.json()["topic_list"]["topics"]
|
||||
|
||||
if len(new_topics) == 0:
|
||||
empty_categories.append(category_id)
|
||||
topics.extend(new_topics)
|
||||
|
||||
for empty_category in empty_categories:
|
||||
self.category_id_map.pop(empty_category)
|
||||
|
||||
for topic in topics:
|
||||
last_time = topic.get("last_posted_at")
|
||||
if not last_time:
|
||||
continue
|
||||
|
||||
last_time_dt = time_str_to_utc(last_time)
|
||||
if (start and start > last_time_dt) or (end and end < last_time_dt):
|
||||
continue
|
||||
|
||||
topic_ids.append(topic["id"])
|
||||
if len(topic_ids) >= self.batch_size:
|
||||
break
|
||||
|
||||
return topic_ids
|
||||
|
||||
def _yield_discourse_documents(
|
||||
self, topic_ids: list[int]
|
||||
self,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
page = 1
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
page += 1
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self.permissions = DiscoursePerms(
|
||||
api_key=credentials["discourse_api_key"],
|
||||
api_username=credentials["discourse_api_username"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def poll_source(
|
||||
@@ -184,16 +218,13 @@ class DiscourseConnector(PollConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.permissions is None:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
|
||||
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
|
||||
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
|
||||
|
||||
self._get_categories_map()
|
||||
|
||||
latest_topic_ids = self._get_latest_topics(
|
||||
start=start_datetime, end=end_datetime
|
||||
)
|
||||
|
||||
yield from self._yield_discourse_documents(latest_topic_ids)
|
||||
yield from self._yield_discourse_documents(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -209,7 +240,5 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_year_ago = current - 24 * 60 * 60 * 360
|
||||
|
||||
latest_docs = connector.poll_source(one_year_ago, current)
|
||||
|
||||
print(next(latest_docs))
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
from danswer.connectors.clickup.connector import ClickupConnector
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
@@ -40,6 +43,8 @@ from danswer.connectors.web.connector import WebConnector
|
||||
from danswer.connectors.wikipedia.connector import WikipediaConnector
|
||||
from danswer.connectors.zendesk.connector import ZendeskConnector
|
||||
from danswer.connectors.zulip.connector import ZulipConnector
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.models import Credential
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@@ -86,6 +91,10 @@ def identify_connector_class(
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
@@ -111,7 +120,6 @@ def identify_connector_class(
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
@@ -119,10 +127,14 @@ def instantiate_connector(
|
||||
source: DocumentSource,
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credentials: dict[str, Any],
|
||||
) -> tuple[BaseConnector, dict[str, Any] | None]:
|
||||
credential: Credential,
|
||||
db_session: Session,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credentials)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
return connector, new_credentials
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
@@ -85,8 +85,18 @@ def _process_file(
|
||||
|
||||
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
|
||||
|
||||
# add a prefix to avoid conflicts with other connectors
|
||||
doc_id = f"FILE_CONNECTOR__{file_name}"
|
||||
if metadata:
|
||||
doc_id = metadata.get("document_id") or doc_id
|
||||
|
||||
# If this is set, we will show this in the UI as the "name" of the file
|
||||
file_display_name_override = all_metadata.get("file_display_name")
|
||||
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
|
||||
file_name
|
||||
)
|
||||
title = (
|
||||
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
|
||||
)
|
||||
|
||||
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
|
||||
if isinstance(time_updated, str):
|
||||
@@ -101,6 +111,7 @@ def _process_file(
|
||||
for k, v in all_metadata.items()
|
||||
if k
|
||||
not in [
|
||||
"document_id",
|
||||
"time_updated",
|
||||
"doc_updated_at",
|
||||
"link",
|
||||
@@ -108,6 +119,7 @@ def _process_file(
|
||||
"secondary_owners",
|
||||
"filename",
|
||||
"file_display_name",
|
||||
"title",
|
||||
]
|
||||
}
|
||||
|
||||
@@ -126,13 +138,13 @@ def _process_file(
|
||||
|
||||
return [
|
||||
Document(
|
||||
id=f"FILE_CONNECTOR__{file_name}", # add a prefix to avoid conflicts with other connectors
|
||||
id=doc_id,
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_display_name_override
|
||||
or os.path.basename(file_name),
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
|
||||
@@ -473,6 +473,11 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
continue
|
||||
|
||||
if self.only_org_public:
|
||||
if "permissions" not in file:
|
||||
continue
|
||||
|
||||
@@ -50,6 +50,12 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import INDEX_SEPARATOR
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.utils.text_processing import make_url_compatible
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ class InputType(str, Enum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
PRUNE = "prune"
|
||||
|
||||
|
||||
class ConnectorMissingCredentialError(PermissionError):
|
||||
@@ -116,7 +118,12 @@ class DocumentBase(BaseModel):
|
||||
# If title is explicitly empty, return a None here for embedding purposes
|
||||
if self.title == "":
|
||||
return None
|
||||
return self.semantic_identifier if self.title is None else self.title
|
||||
replace_chars = set(RETURN_SEPARATOR)
|
||||
title = self.semantic_identifier if self.title is None else self.title
|
||||
for char in replace_chars:
|
||||
title = title.replace(char, " ")
|
||||
title = title.strip()
|
||||
return title
|
||||
|
||||
def get_metadata_str_attributes(self) -> list[str] | None:
|
||||
if not self.metadata:
|
||||
|
||||
@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
compare_time = time.mktime(
|
||||
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
|
||||
)
|
||||
if compare_time <= end or compare_time > start:
|
||||
if compare_time > start and compare_time <= end:
|
||||
filtered_pages += [NotionPage(**page)]
|
||||
return filtered_pages
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -23,11 +24,12 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector):
|
||||
class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -77,8 +79,9 @@ class SalesforceConnector(LoadConnector, PollConnector):
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
extracted_id = f"SALESFORCE_{object_dict['Id']}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
|
||||
salesforce_id = object_dict["Id"]
|
||||
danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_object_text = extract_dict_text(object_dict)
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
@@ -89,7 +92,7 @@ class SalesforceConnector(LoadConnector, PollConnector):
|
||||
]
|
||||
|
||||
doc = Document(
|
||||
id=extracted_id,
|
||||
id=danswer_salesforce_id,
|
||||
sections=[Section(link=extracted_link, text=extracted_object_text)],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
@@ -229,8 +232,6 @@ class SalesforceConnector(LoadConnector, PollConnector):
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
return self._fetch_from_salesforce()
|
||||
|
||||
def poll_source(
|
||||
@@ -242,6 +243,20 @@ class SalesforceConnector(LoadConnector, PollConnector):
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
all_retrieved_ids: set[str] = set()
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
all_retrieved_ids.update(
|
||||
f"{ID_PREFIX}{instance_dict.get('Id', '')}"
|
||||
for instance_dict in query_result["records"]
|
||||
)
|
||||
|
||||
return all_retrieved_ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(
|
||||
|
||||
@@ -29,6 +29,7 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import pdf_to_text
|
||||
from danswer.file_processing.html_utils import web_html_cleanup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.sitemap import list_pages_for_site
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -145,16 +146,21 @@ def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
result = [
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
for loc_tag in soup.find_all("loc")
|
||||
]
|
||||
if not result:
|
||||
|
||||
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
|
||||
# the given url doesn't look like a sitemap, let's try to find one
|
||||
urls = list_pages_for_site(sitemap_url)
|
||||
|
||||
if len(urls) == 0:
|
||||
raise ValueError(
|
||||
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
|
||||
)
|
||||
|
||||
return result
|
||||
return urls
|
||||
|
||||
|
||||
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
|
||||
@@ -264,7 +270,7 @@ class WebConnector(LoadConnector):
|
||||
id=current_url,
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split(".")[-1],
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
@@ -81,7 +82,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
doc_batch = []
|
||||
for article in articles:
|
||||
if article.body is None or article.draft:
|
||||
if (
|
||||
article.body is None
|
||||
or article.draft
|
||||
or any(
|
||||
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
for label in article.label_names
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
doc_batch.append(_article_to_document(article))
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
@@ -353,6 +354,22 @@ def build_quotes_block(
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def build_standard_answer_blocks(
|
||||
answer_message: str,
|
||||
) -> list[Block]:
|
||||
generate_button_block = ButtonElement(
|
||||
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
|
||||
text="Generate Full Answer",
|
||||
)
|
||||
answer_block = SectionBlock(text=answer_message)
|
||||
return [
|
||||
answer_block,
|
||||
ActionsBlock(
|
||||
elements=[generate_button_block],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
@@ -457,7 +474,7 @@ def build_follow_up_resolved_blocks(
|
||||
if tag_str:
|
||||
tag_str += " "
|
||||
|
||||
group_str = " ".join([f"<!subteam^{group}>" for group in group_ids])
|
||||
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
|
||||
if group_str:
|
||||
group_str += " "
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
|
||||
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
|
||||
SLACK_CHANNEL_ID = "channel_id"
|
||||
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
|
||||
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"
|
||||
|
||||
|
||||
class FeedbackVisibility(str, Enum):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -8,6 +9,7 @@ from slack_sdk.socket_mode import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
@@ -21,12 +23,17 @@ from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_message import (
|
||||
remove_scheduled_feedback_reminder,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
|
||||
handle_regular_answer,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import fetch_groupids_from_names
|
||||
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_feedback_visibility
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
@@ -36,7 +43,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger_base = setup_logger()
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def handle_doc_feedback_button(
|
||||
@@ -44,7 +51,7 @@ def handle_doc_feedback_button(
|
||||
client: SocketModeClient,
|
||||
) -> None:
|
||||
if not (actions := req.payload.get("actions")):
|
||||
logger_base.error("Missing actions. Unable to build the source feedback view")
|
||||
logger.error("Missing actions. Unable to build the source feedback view")
|
||||
return
|
||||
|
||||
# Extracts the feedback_id coming from the 'source feedback' button
|
||||
@@ -72,6 +79,66 @@ def handle_doc_feedback_button(
|
||||
)
|
||||
|
||||
|
||||
def handle_generate_answer_button(
|
||||
req: SocketModeRequest,
|
||||
client: SocketModeClient,
|
||||
) -> None:
|
||||
channel_id = req.payload["channel"]["id"]
|
||||
channel_name = req.payload["channel"]["name"]
|
||||
message_ts = req.payload["message"]["ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
user_id = req.payload["user"]["id"]
|
||||
|
||||
if not thread_ts:
|
||||
raise ValueError("Missing thread_ts in the payload")
|
||||
|
||||
thread_messages = read_slack_thread(
|
||||
channel=channel_id, thread=thread_ts, client=client.web_client
|
||||
)
|
||||
# remove all assistant messages till we get to the last user message
|
||||
# we want the new answer to be generated off of the last "question" in
|
||||
# the thread
|
||||
for i in range(len(thread_messages) - 1, -1, -1):
|
||||
if thread_messages[i].role == MessageType.USER:
|
||||
break
|
||||
if thread_messages[i].role == MessageType.ASSISTANT:
|
||||
thread_messages.pop(i)
|
||||
|
||||
# tell the user that we're working on it
|
||||
# Send an ephemeral message to the user that we're generating the answer
|
||||
respond_in_thread(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
receiver_ids=[user_id],
|
||||
text="I'm working on generating a full answer for you. This may take a moment...",
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
|
||||
handle_regular_answer(
|
||||
message_info=SlackMessageInfo(
|
||||
thread_messages=thread_messages,
|
||||
channel_to_respond=channel_id,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=user_id or None,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=False,
|
||||
),
|
||||
slack_bot_config=slack_bot_config,
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
logger=cast(logging.Logger, logger),
|
||||
feedback_reminder_id=None,
|
||||
)
|
||||
|
||||
|
||||
def handle_slack_feedback(
|
||||
feedback_id: str,
|
||||
feedback_type: str,
|
||||
@@ -129,7 +196,7 @@ def handle_slack_feedback(
|
||||
feedback=feedback,
|
||||
)
|
||||
else:
|
||||
logger_base.error(f"Feedback type '{feedback_type}' not supported")
|
||||
logger.error(f"Feedback type '{feedback_type}' not supported")
|
||||
|
||||
if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [
|
||||
LIKE_BLOCK_ACTION_ID,
|
||||
@@ -193,11 +260,11 @@ def handle_followup_button(
|
||||
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
|
||||
remaining = None
|
||||
if tag_names:
|
||||
tag_ids, remaining = fetch_userids_from_emails(
|
||||
tag_ids, remaining = fetch_user_ids_from_emails(
|
||||
tag_names, client.web_client
|
||||
)
|
||||
if remaining:
|
||||
group_ids, _ = fetch_groupids_from_names(remaining, client.web_client)
|
||||
group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client)
|
||||
|
||||
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
|
||||
|
||||
@@ -272,7 +339,7 @@ def handle_followup_resolved_button(
|
||||
)
|
||||
|
||||
if not response.get("ok"):
|
||||
logger_base.error("Unable to delete message for resolved")
|
||||
logger.error("Unable to delete message for resolved")
|
||||
|
||||
if immediate:
|
||||
msg_text = f"{clicker_name} has marked this question as resolved!"
|
||||
|
||||
@@ -1,91 +1,34 @@
|
||||
import datetime
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_sources_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
|
||||
handle_regular_answer,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
handle_standard_answers,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llm_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
|
||||
logger_base = setup_logger()
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> RT:
|
||||
if not srl.is_available():
|
||||
func_randid, position = srl.init_waiter()
|
||||
srl.notify(client, channel, position, thread_ts)
|
||||
while not srl.is_available():
|
||||
srl.waiter(func_randid)
|
||||
srl.acquire_slot()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender:
|
||||
@@ -173,17 +116,9 @@ def remove_scheduled_feedback_reminder(
|
||||
|
||||
def handle_message(
|
||||
message_info: SlackMessageInfo,
|
||||
channel_config: SlackBotConfig | None,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -200,14 +135,22 @@ def handle_message(
|
||||
)
|
||||
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
sender_id = message_info.sender
|
||||
bypass_filters = message_info.bypass_filters
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_bot_dm = message_info.is_bot_dm
|
||||
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
action = "slack_slash_message"
|
||||
elif bypass_filters:
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = channel_config.persona if channel_config else None
|
||||
persona = slack_bot_config.persona if slack_bot_config else None
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
@@ -215,36 +158,13 @@ def handle_message(
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if channel_config is None
|
||||
else channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
# List of user id to send message to, if None, send to everyone in channel
|
||||
send_to: list[str] | None = None
|
||||
respond_tag_only = False
|
||||
respond_team_member_list = None
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
channel_config
|
||||
and channel_config.persona
|
||||
and channel_config.persona.document_sets
|
||||
):
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
respond_member_group_list = None
|
||||
|
||||
channel_conf = None
|
||||
if channel_config and channel_config.channel_config:
|
||||
channel_conf = channel_config.channel_config
|
||||
if slack_bot_config and slack_bot_config.channel_config:
|
||||
channel_conf = slack_bot_config.channel_config
|
||||
if not bypass_filters and "answer_filters" in channel_conf:
|
||||
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
|
||||
|
||||
if (
|
||||
"questionmark_prefilter" in channel_conf["answer_filters"]
|
||||
and "?" not in messages[-1].message
|
||||
@@ -261,7 +181,7 @@ def handle_message(
|
||||
)
|
||||
|
||||
respond_tag_only = channel_conf.get("respond_tag_only") or False
|
||||
respond_team_member_list = channel_conf.get("respond_team_member_list") or None
|
||||
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
|
||||
|
||||
if respond_tag_only and not bypass_filters:
|
||||
logger.info(
|
||||
@@ -270,12 +190,23 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
if respond_team_member_list:
|
||||
send_to, _ = fetch_userids_from_emails(respond_team_member_list, client)
|
||||
# List of user id to send message to, if None, send to everyone in channel
|
||||
send_to: list[str] | None = None
|
||||
missing_users: list[str] | None = None
|
||||
if respond_member_group_list:
|
||||
send_to, missing_ids = fetch_user_ids_from_emails(
|
||||
respond_member_group_list, client
|
||||
)
|
||||
|
||||
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client)
|
||||
send_to = list(set(send_to + user_ids)) if send_to else user_ids
|
||||
|
||||
if missing_users:
|
||||
logger.warning(f"Failed to find these users/groups: {missing_users}")
|
||||
|
||||
# If configured to respond to team members only, then cannot be used with a /DanswerBot command
|
||||
# which would just respond to the sender
|
||||
if respond_team_member_list and is_bot_msg:
|
||||
if send_to and is_bot_msg:
|
||||
if sender_id:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
@@ -290,324 +221,28 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
@retry(
|
||||
tries=num_retries,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
action = "slack_slash_message"
|
||||
elif bypass_filters:
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(db_session, new_message_request.persona_id),
|
||||
)
|
||||
llm = get_llm_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=None,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
# it allows the slack flow to extract out filters from the user query
|
||||
filters = BaseFilters(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# first check if we need to respond with a standard answer
|
||||
used_standard_answer = handle_standard_answers(
|
||||
message_info=message_info,
|
||||
receiver_ids=send_to,
|
||||
slack_bot_config=slack_bot_config,
|
||||
prompt=prompt,
|
||||
logger=logger,
|
||||
client=client,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Default True because no other ways to apply filters in Slack (no nice UI)
|
||||
auto_detect_filters = (
|
||||
persona.llm_filter_extraction if persona is not None else True
|
||||
)
|
||||
if disable_auto_detect_filters:
|
||||
auto_detect_filters = False
|
||||
|
||||
retrieval_details = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
f"in {num_retries} attempts"
|
||||
)
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
try:
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
|
||||
return True
|
||||
|
||||
# Edge case handling, for tracking down the Slack usage issue
|
||||
if answer is None:
|
||||
assert DISABLE_GENERATIVE_AI is True
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=send_to,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=[
|
||||
SectionBlock(
|
||||
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
|
||||
)
|
||||
],
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if respond_team_member_list:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
if used_standard_answer:
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
try:
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.info(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
update_emote_react(
|
||||
emoji=DANSWER_FOLLOWUP_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=False,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer:
|
||||
logger.debug(answer.answer)
|
||||
return True
|
||||
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return True
|
||||
|
||||
if not answer.answer and disable_docs_only_answer:
|
||||
logger.info(
|
||||
"Unable to find answer - not responding since the "
|
||||
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
|
||||
)
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=answer.answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
# if no standard answer applies, try a regular answer
|
||||
issue_with_regular_answer = handle_regular_answer(
|
||||
message_info=message_info,
|
||||
slack_bot_config=slack_bot_config,
|
||||
receiver_ids=send_to,
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=send_to,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
logger=logger,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if respond_team_member_list:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
return issue_with_regular_answer
|
||||
|
||||
@@ -0,0 +1,465 @@
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_sources_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> RT:
|
||||
if not srl.is_available():
|
||||
func_randid, position = srl.init_waiter()
|
||||
srl.notify(client, channel, position, thread_ts)
|
||||
while not srl.is_available():
|
||||
srl.waiter(func_randid)
|
||||
srl.acquire_slot()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handle_regular_answer(
|
||||
message_info: SlackMessageInfo,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
receiver_ids: list[str] | None,
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
logger: logging.Logger,
|
||||
feedback_reminder_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_bot_config.persona if slack_bot_config else None
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_bot_config
|
||||
and slack_bot_config.persona
|
||||
and slack_bot_config.persona.document_sets
|
||||
):
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_bot_config is None
|
||||
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to:
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=num_retries,
|
||||
delay=0.25,
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(db_session, new_message_request.persona_id),
|
||||
)
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=None,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
# it allows the slack flow to extract out filters from the user query
|
||||
filters = BaseFilters(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
)
|
||||
|
||||
# Default True because no other ways to apply filters in Slack (no nice UI)
|
||||
# Commenting this out because this is only available to the slackbot for now
|
||||
# later we plan to implement this at the persona level where this will get
|
||||
# commented back in
|
||||
# auto_detect_filters = (
|
||||
# persona.llm_filter_extraction if persona is not None else True
|
||||
# )
|
||||
auto_detect_filters = (
|
||||
slack_bot_config.enable_auto_filters
|
||||
if slack_bot_config is not None
|
||||
else False
|
||||
)
|
||||
retrieval_details = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
f"in {num_retries} attempts"
|
||||
)
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
# Edge case handling, for tracking down the Slack usage issue
|
||||
if answer is None:
|
||||
assert DISABLE_GENERATIVE_AI is True
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=[
|
||||
SectionBlock(
|
||||
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
|
||||
)
|
||||
],
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.info(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
update_emote_react(
|
||||
emoji=DANSWER_FOLLOWUP_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=False,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer:
|
||||
logger.debug(answer.answer)
|
||||
return True
|
||||
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return True
|
||||
|
||||
if not answer.answer and disable_docs_only_answer:
|
||||
logger.info(
|
||||
"Unable to find answer - not responding since the "
|
||||
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text="Found no citations or quotes when trying to answer.",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=answer.answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,216 @@
|
||||
import logging
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_sessions
|
||||
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from danswer.db.standard_answer import find_matching_standard_answers
|
||||
from danswer.server.manage.models import StandardAnswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def oneoff_standard_answers(
|
||||
message: str,
|
||||
slack_bot_categories: list[str],
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
"""
|
||||
Respond to the user message if it matches any configured standard answers.
|
||||
|
||||
Returns a list of matching StandardAnswers if found, otherwise None.
|
||||
"""
|
||||
configured_standard_answers = {
|
||||
standard_answer
|
||||
for category in fetch_standard_answer_categories_by_names(
|
||||
slack_bot_categories, db_session=db_session
|
||||
)
|
||||
for standard_answer in category.standard_answers
|
||||
}
|
||||
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=message,
|
||||
id_in=[answer.id for answer in configured_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
server_standard_answers = [
|
||||
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
|
||||
]
|
||||
return server_standard_answers
|
||||
|
||||
|
||||
def handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_bot_config: SlackBotConfig | None,
|
||||
prompt: Prompt | None,
|
||||
logger: logging.Logger,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Potentially respond to the user message depending on whether the user's message matches
|
||||
any of the configured standard answers and also whether those answers have already been
|
||||
provided in the current thread.
|
||||
|
||||
Returns True if standard answers are found to match the user's message and therefore,
|
||||
we still need to respond to the users.
|
||||
"""
|
||||
# if no channel config, then no standard answers are configured
|
||||
if not slack_bot_config:
|
||||
return False
|
||||
|
||||
slack_thread_id = message_info.thread_to_respond
|
||||
configured_standard_answer_categories = (
|
||||
slack_bot_config.standard_answer_categories if slack_bot_config else []
|
||||
)
|
||||
configured_standard_answers = set(
|
||||
[
|
||||
standard_answer
|
||||
for standard_answer_category in configured_standard_answer_categories
|
||||
for standard_answer in standard_answer_category.standard_answers
|
||||
]
|
||||
)
|
||||
query_msg = message_info.thread_messages[-1]
|
||||
|
||||
if slack_thread_id is None:
|
||||
used_standard_answer_ids = set([])
|
||||
else:
|
||||
chat_sessions = get_chat_sessions_by_slack_thread_id(
|
||||
slack_thread_id=slack_thread_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
chat_messages = get_chat_messages_by_sessions(
|
||||
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
)
|
||||
used_standard_answer_ids = set(
|
||||
[
|
||||
standard_answer.id
|
||||
for chat_message in chat_messages
|
||||
for standard_answer in chat_message.standard_answers
|
||||
]
|
||||
)
|
||||
|
||||
usable_standard_answers = configured_standard_answers.difference(
|
||||
used_standard_answer_ids
|
||||
)
|
||||
if usable_standard_answers:
|
||||
matching_standard_answers = find_matching_standard_answers(
|
||||
query=query_msg.message,
|
||||
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
matching_standard_answers = []
|
||||
if matching_standard_answers:
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="",
|
||||
user_id=None,
|
||||
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
one_shot=True,
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=query_msg.message,
|
||||
token_count=0,
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
formatted_answers = []
|
||||
for standard_answer in matching_standard_answers:
|
||||
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
|
||||
formatted_answer = (
|
||||
f'Since you mentioned _"{standard_answer.keyword}"_, '
|
||||
f"I thought this might be useful: \n\n{block_quotified_answer}"
|
||||
)
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
_ = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=answer_message,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
restate_question_blocks = get_restate_blocks(
|
||||
msg=query_msg.message,
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
)
|
||||
|
||||
answer_blocks = build_standard_answer_blocks(
|
||||
answer_message=answer_message,
|
||||
)
|
||||
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_info.msg_to_respond,
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
if receiver_ids and slack_thread_id:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to send standard answer message: {e}")
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
19
backend/danswer/danswerbot/slack/handlers/utils.py
Normal file
19
backend/danswer/danswerbot/slack/handlers/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
thread_ts: str,
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
@@ -4,9 +4,9 @@ from danswer.configs.constants import DocumentSource
|
||||
def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
# TODO: store these images somewhere better
|
||||
if source == DocumentSource.WEB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Web.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Web.png"
|
||||
if source == DocumentSource.FILE.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.GOOGLE_SITES.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/GoogleSites.png"
|
||||
if source == DocumentSource.SLACK.value:
|
||||
@@ -20,13 +20,13 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
if source == DocumentSource.GITLAB.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gitlab.png"
|
||||
if source == DocumentSource.CONFLUENCE.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Confluence.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Confluence.png"
|
||||
if source == DocumentSource.JIRA.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Jira.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Jira.png"
|
||||
if source == DocumentSource.NOTION.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Notion.png"
|
||||
if source == DocumentSource.ZENDESK.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Zendesk.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Zendesk.png"
|
||||
if source == DocumentSource.GONG.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Gong.png"
|
||||
if source == DocumentSource.LINEAR.value:
|
||||
@@ -38,7 +38,7 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
if source == DocumentSource.ZULIP.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Zulip.png"
|
||||
if source == DocumentSource.GURU.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/Guru.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/Guru.png"
|
||||
if source == DocumentSource.HUBSPOT.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/HubSpot.png"
|
||||
if source == DocumentSource.DOCUMENT360.value:
|
||||
@@ -51,8 +51,8 @@ def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/web/public/Sharepoint.png"
|
||||
if source == DocumentSource.REQUESTTRACKER.value:
|
||||
# just use file icon for now
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.INGESTION_API.value:
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/improve-slack-flow/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/danswer-ai/danswer/main/backend/slackbot_images/File.png"
|
||||
|
||||
@@ -18,6 +18,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
@@ -27,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_but
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_followup_resolved_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_generate_answer_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
|
||||
from danswer.danswerbot.slack.handlers.handle_message import handle_message
|
||||
from danswer.danswerbot.slack.handlers.handle_message import (
|
||||
@@ -56,7 +60,6 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
|
||||
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
|
||||
# the cause.
|
||||
@@ -70,6 +73,9 @@ _SLACK_GREETINGS_TO_IGNORE = {
|
||||
":wave:",
|
||||
}
|
||||
|
||||
# this is always (currently) the user id of Slack's official slackbot
|
||||
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
@@ -92,6 +98,15 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
channel_specific_logger.error("Cannot respond to empty message - skipping")
|
||||
return False
|
||||
|
||||
if (
|
||||
req.payload.setdefault("event", {}).get("user", "")
|
||||
== _OFFICIAL_SLACKBOT_USER_ID
|
||||
):
|
||||
channel_specific_logger.info(
|
||||
"Ignoring messages from Slack's official Slackbot"
|
||||
)
|
||||
return False
|
||||
|
||||
if (
|
||||
msg in _SLACK_GREETINGS_TO_IGNORE
|
||||
or remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
@@ -255,6 +270,7 @@ def build_request_details(
|
||||
thread_messages=thread_messages,
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=event.get("user") or None,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
@@ -272,6 +288,7 @@ def build_request_details(
|
||||
thread_messages=[single_msg],
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=None,
|
||||
thread_to_respond=None,
|
||||
sender=sender,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
@@ -341,7 +358,7 @@ def process_message(
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
channel_config=slack_bot_config,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
@@ -379,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
return handle_followup_resolved_button(req, client, immediate=True)
|
||||
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
|
||||
return handle_followup_resolved_button(req, client, immediate=False)
|
||||
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
|
||||
return handle_generate_answer_button(req, client)
|
||||
|
||||
|
||||
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
@@ -450,13 +469,13 @@ if __name__ == "__main__":
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
|
||||
@@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel):
|
||||
thread_messages: list[ThreadMessage]
|
||||
channel_to_respond: str
|
||||
msg_to_respond: str | None
|
||||
thread_to_respond: str | None
|
||||
sender: str | None
|
||||
bypass_filters: bool # User has tagged @DanswerBot
|
||||
is_bot_msg: bool # User is using /DanswerBot
|
||||
|
||||
@@ -30,7 +30,7 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
@@ -58,7 +58,7 @@ def rephrase_slack_message(msg: str) -> str:
|
||||
return messages
|
||||
|
||||
try:
|
||||
llm = get_default_llm(use_fast_llm=False, timeout=5)
|
||||
llm, _ = get_default_llms(timeout=5)
|
||||
except GenAIDisabledException:
|
||||
logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
|
||||
return msg
|
||||
@@ -77,17 +77,25 @@ def update_emote_react(
|
||||
remove: bool,
|
||||
client: WebClient,
|
||||
) -> None:
|
||||
if not message_ts:
|
||||
logger.error(f"Tried to remove a react in {channel} but no message specified")
|
||||
return
|
||||
try:
|
||||
if not message_ts:
|
||||
logger.error(
|
||||
f"Tried to remove a react in {channel} but no message specified"
|
||||
)
|
||||
return
|
||||
|
||||
func = client.reactions_remove if remove else client.reactions_add
|
||||
slack_call = make_slack_api_rate_limited(func) # type: ignore
|
||||
slack_call(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
func = client.reactions_remove if remove else client.reactions_add
|
||||
slack_call = make_slack_api_rate_limited(func) # type: ignore
|
||||
slack_call(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if remove:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
else:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
|
||||
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
|
||||
@@ -136,16 +144,13 @@ def respond_in_thread(
|
||||
receiver_ids: list[str] | None = None,
|
||||
metadata: Metadata | None = None,
|
||||
unfurl: bool = True,
|
||||
) -> None:
|
||||
) -> list[str]:
|
||||
if not text and not blocks:
|
||||
raise ValueError("One of `text` or `blocks` must be provided")
|
||||
|
||||
message_ids: list[str] = []
|
||||
if not receiver_ids:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
|
||||
if not receiver_ids:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
text=text,
|
||||
@@ -157,7 +162,9 @@ def respond_in_thread(
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
for receiver in receiver_ids:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
@@ -171,6 +178,9 @@ def respond_in_thread(
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
message_ids.append(response["message_ts"])
|
||||
|
||||
return message_ids
|
||||
|
||||
|
||||
def build_feedback_id(
|
||||
@@ -292,7 +302,7 @@ def get_channel_name_from_id(
|
||||
raise e
|
||||
|
||||
|
||||
def fetch_userids_from_emails(
|
||||
def fetch_user_ids_from_emails(
|
||||
user_emails: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
user_ids: list[str] = []
|
||||
@@ -308,32 +318,72 @@ def fetch_userids_from_emails(
|
||||
return user_ids, failed_to_find
|
||||
|
||||
|
||||
def fetch_groupids_from_names(
|
||||
names: list[str], client: WebClient
|
||||
def fetch_user_ids_from_groups(
|
||||
given_names: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
group_ids: set[str] = set()
|
||||
user_ids: list[str] = []
|
||||
failed_to_find: list[str] = []
|
||||
try:
|
||||
response = client.usergroups_list()
|
||||
if not isinstance(response.data, dict):
|
||||
logger.error("Error fetching user groups")
|
||||
return user_ids, given_names
|
||||
|
||||
all_group_data = response.data.get("usergroups", [])
|
||||
name_id_map = {d["name"]: d["id"] for d in all_group_data}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
|
||||
for given_name in given_names:
|
||||
group_id = name_id_map.get(given_name) or handle_id_map.get(
|
||||
given_name.lstrip("@")
|
||||
)
|
||||
if not group_id:
|
||||
failed_to_find.append(given_name)
|
||||
continue
|
||||
try:
|
||||
response = client.usergroups_users_list(usergroup=group_id)
|
||||
if isinstance(response.data, dict):
|
||||
user_ids.extend(response.data.get("users", []))
|
||||
else:
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching user group ids: {str(e)}")
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching user groups: {str(e)}")
|
||||
failed_to_find = given_names
|
||||
|
||||
return user_ids, failed_to_find
|
||||
|
||||
|
||||
def fetch_group_ids_from_names(
|
||||
given_names: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
group_data: list[str] = []
|
||||
failed_to_find: list[str] = []
|
||||
|
||||
try:
|
||||
response = client.usergroups_list()
|
||||
if response.get("ok") and "usergroups" in response.data:
|
||||
all_groups_dicts = response.data["usergroups"] # type: ignore
|
||||
name_id_map = {d["name"]: d["id"] for d in all_groups_dicts}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_groups_dicts}
|
||||
for group in names:
|
||||
if group in name_id_map:
|
||||
group_ids.add(name_id_map[group])
|
||||
elif group in handle_id_map:
|
||||
group_ids.add(handle_id_map[group])
|
||||
else:
|
||||
failed_to_find.append(group)
|
||||
else:
|
||||
# Most likely a Slack App scope issue
|
||||
if not isinstance(response.data, dict):
|
||||
logger.error("Error fetching user groups")
|
||||
return group_data, given_names
|
||||
|
||||
all_group_data = response.data.get("usergroups", [])
|
||||
|
||||
name_id_map = {d["name"]: d["id"] for d in all_group_data}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
|
||||
|
||||
for given_name in given_names:
|
||||
id = handle_id_map.get(given_name.lstrip("@"))
|
||||
id = id or name_id_map.get(given_name)
|
||||
if id:
|
||||
group_data.append(id)
|
||||
else:
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
failed_to_find = given_names
|
||||
logger.error(f"Error fetching user groups: {str(e)}")
|
||||
|
||||
return list(group_ids), failed_to_find
|
||||
return group_data, failed_to_find
|
||||
|
||||
|
||||
def fetch_user_semantic_id_from_id(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
@@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def get_default_admin_user_emails() -> list[str]:
|
||||
"""Returns a list of emails who should default to Admin role.
|
||||
Only used in the EE version. For MIT, just return empty list."""
|
||||
get_default_admin_user_emails_fn: Callable[
|
||||
[], list[str]
|
||||
] = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
|
||||
)
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
@@ -32,7 +47,7 @@ async def get_user_count() -> int:
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
else:
|
||||
create_dict["role"] = UserRole.BASIC
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.chat.models import LLMRelevanceSummaryResponse
|
||||
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatMessage__SearchDoc
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import ChatSessionSharedStatus
|
||||
from danswer.db.models import Prompt
|
||||
@@ -19,6 +28,7 @@ from danswer.db.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.pg_file_store import delete_lobj_by_name
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
@@ -29,6 +39,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -63,17 +74,59 @@ def get_chat_session_by_id(
|
||||
return chat_session
|
||||
|
||||
|
||||
def get_chat_sessions_by_slack_thread_id(
|
||||
slack_thread_id: str,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Sequence[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
|
||||
)
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_first_messages_for_chat_sessions(
|
||||
chat_session_ids: list[int], db_session: Session
|
||||
) -> dict[int, str]:
|
||||
subquery = (
|
||||
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
|
||||
.where(
|
||||
and_(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER, # Select USER messages
|
||||
)
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
|
||||
subquery,
|
||||
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
|
||||
& (ChatMessage.id == subquery.c.min_id),
|
||||
)
|
||||
|
||||
first_messages = db_session.execute(query).all()
|
||||
return dict([(row.chat_session_id, row.message) for row in first_messages])
|
||||
|
||||
|
||||
def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
include_one_shot: bool = False,
|
||||
only_one_shot: bool = False,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if not include_one_shot:
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
@@ -83,15 +136,71 @@ def get_chat_sessions_by_user(
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def delete_search_doc_message_relationship(
|
||||
message_id: int, db_session: Session
|
||||
) -> None:
|
||||
db_session.query(ChatMessage__SearchDoc).filter(
|
||||
ChatMessage__SearchDoc.chat_message_id == message_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
|
||||
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
orphaned_docs = (
|
||||
db_session.query(SearchDoc)
|
||||
.outerjoin(ChatMessage__SearchDoc)
|
||||
.filter(ChatMessage__SearchDoc.chat_message_id.is_(None))
|
||||
.all()
|
||||
)
|
||||
for doc in orphaned_docs:
|
||||
db_session.delete(doc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_messages_and_files_from_chat_session(
|
||||
chat_session_id: int, db_session: Session
|
||||
) -> None:
|
||||
# Select messages older than cutoff_time with files
|
||||
messages_with_files = db_session.execute(
|
||||
select(ChatMessage.id, ChatMessage.files).where(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for id, files in messages_with_files:
|
||||
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
|
||||
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
|
||||
for file_info in files or {}:
|
||||
lobj_name = file_info.get("id")
|
||||
if lobj_name:
|
||||
logger.info(f"Deleting file with name: {lobj_name}")
|
||||
delete_lobj_by_name(lobj_name, db_session)
|
||||
|
||||
db_session.execute(
|
||||
delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
delete_orphaned_search_docs(db_session)
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None = None,
|
||||
persona_id: int,
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
@@ -101,6 +210,7 @@ def create_chat_session(
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
|
||||
db_session.add(chat_session)
|
||||
@@ -139,25 +249,30 @@ def delete_chat_session(
|
||||
db_session: Session,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
stmt_messages = delete(ChatMessage).where(
|
||||
ChatMessage.chat_session_id == chat_session_id
|
||||
)
|
||||
db_session.execute(stmt_messages)
|
||||
|
||||
stmt = delete(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
db_session.execute(stmt)
|
||||
|
||||
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
|
||||
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
chat_session.deleted = True
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
|
||||
old_sessions = db_session.execute(
|
||||
select(ChatSession.user_id, ChatSession.id).where(
|
||||
ChatSession.time_created < cutoff_time
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_id, session_id in old_sessions:
|
||||
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
|
||||
|
||||
|
||||
def get_chat_message(
|
||||
chat_message_id: int,
|
||||
user_id: UUID | None,
|
||||
@@ -183,6 +298,39 @@ def get_chat_message(
|
||||
return chat_message
|
||||
|
||||
|
||||
def get_chat_messages_by_sessions(
|
||||
chat_session_ids: list[int],
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
skip_permission_check: bool = False,
|
||||
) -> Sequence[ChatMessage]:
|
||||
if not skip_permission_check:
|
||||
for chat_session_id in chat_session_ids:
|
||||
get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.chat_session_id.in_(chat_session_ids))
|
||||
.order_by(nullsfirst(ChatMessage.parent_message))
|
||||
)
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_search_docs_for_chat_message(
|
||||
chat_message_id: int, db_session: Session
|
||||
) -> list[SearchDoc]:
|
||||
stmt = (
|
||||
select(SearchDoc)
|
||||
.join(
|
||||
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
|
||||
)
|
||||
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_chat_messages_by_session(
|
||||
chat_session_id: int,
|
||||
user_id: UUID | None,
|
||||
@@ -203,8 +351,6 @@ def get_chat_messages_by_session(
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
|
||||
if prefetch_tool_calls:
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -259,6 +405,7 @@ def create_new_chat_message(
|
||||
rephrased_query: str | None = None,
|
||||
error: str | None = None,
|
||||
reference_docs: list[DBSearchDoc] | None = None,
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
@@ -277,6 +424,7 @@ def create_new_chat_message(
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
)
|
||||
|
||||
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
|
||||
@@ -371,16 +519,46 @@ def get_doc_query_identifiers_from_model(
|
||||
)
|
||||
raise ValueError("Docs references do not belong to user")
|
||||
|
||||
if any(
|
||||
[doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs]
|
||||
):
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
try:
|
||||
if any(
|
||||
[
|
||||
doc.chat_messages[0].chat_session_id != chat_session.id
|
||||
for doc in search_docs
|
||||
]
|
||||
):
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
except IndexError:
|
||||
# This happens when the doc has no chat_messages associated with it.
|
||||
# which happens as an edge case where the chat message failed to save
|
||||
# This usually happens when the LLM fails either immediately or partially through.
|
||||
raise RuntimeError("Chat session failed, please start a new session.")
|
||||
|
||||
doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs]
|
||||
|
||||
return doc_query_identifiers
|
||||
|
||||
|
||||
def update_search_docs_table_with_relevance(
|
||||
db_session: Session,
|
||||
reference_db_search_docs: list[SearchDoc],
|
||||
relevance_summary: LLMRelevanceSummaryResponse,
|
||||
) -> None:
|
||||
for search_doc in reference_db_search_docs:
|
||||
relevance_data = relevance_summary.relevance_summaries.get(
|
||||
f"{search_doc.document_id}-{search_doc.chunk_ind}"
|
||||
)
|
||||
if relevance_data is not None:
|
||||
db_session.execute(
|
||||
update(SearchDoc)
|
||||
.where(SearchDoc.id == search_doc.id)
|
||||
.values(
|
||||
is_relevant=relevance_data.relevant,
|
||||
relevance_explanation=relevance_data.content,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
@@ -395,17 +573,19 @@ def create_db_search_doc(
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=server_search_doc.relevance_explanation,
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=server_search_doc.match_highlights,
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=server_search_doc.primary_owners,
|
||||
secondary_owners=server_search_doc.secondary_owners,
|
||||
is_internet=server_search_doc.is_internet,
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
db_session.commit()
|
||||
|
||||
return db_search_doc
|
||||
|
||||
|
||||
@@ -431,14 +611,17 @@ def translate_db_search_doc_to_server_search_doc(
|
||||
hidden=db_search_doc.hidden,
|
||||
metadata=db_search_doc.doc_metadata if not remove_doc_content else {},
|
||||
score=db_search_doc.score,
|
||||
match_highlights=db_search_doc.match_highlights
|
||||
if not remove_doc_content
|
||||
else [],
|
||||
match_highlights=(
|
||||
db_search_doc.match_highlights if not remove_doc_content else []
|
||||
),
|
||||
relevance_explanation=db_search_doc.relevance_explanation,
|
||||
is_relevant=db_search_doc.is_relevant,
|
||||
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
|
||||
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
|
||||
secondary_owners=db_search_doc.secondary_owners
|
||||
if not remove_doc_content
|
||||
else [],
|
||||
secondary_owners=(
|
||||
db_search_doc.secondary_owners if not remove_doc_content else []
|
||||
),
|
||||
is_internet=db_search_doc.is_internet,
|
||||
)
|
||||
|
||||
|
||||
@@ -456,9 +639,11 @@ def get_retrieval_docs_from_chat_message(
|
||||
|
||||
|
||||
def translate_db_message_to_chat_message_detail(
|
||||
chat_message: ChatMessage, remove_doc_content: bool = False
|
||||
chat_message: ChatMessage,
|
||||
remove_doc_content: bool = False,
|
||||
) -> ChatMessageDetail:
|
||||
chat_msg_detail = ChatMessageDetail(
|
||||
chat_session_id=chat_message.chat_session_id,
|
||||
message_id=chat_message.id,
|
||||
parent_message=chat_message.parent_message,
|
||||
latest_child_message=chat_message.latest_child_message,
|
||||
@@ -479,6 +664,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import Connector
|
||||
@@ -84,6 +85,9 @@ def create_connector(
|
||||
input_type=connector_data.input_type,
|
||||
connector_specific_config=connector_data.connector_specific_config,
|
||||
refresh_freq=connector_data.refresh_freq,
|
||||
prune_freq=connector_data.prune_freq
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ,
|
||||
disabled=connector_data.disabled,
|
||||
)
|
||||
db_session.add(connector)
|
||||
@@ -113,6 +117,11 @@ def update_connector(
|
||||
connector.input_type = connector_data.input_type
|
||||
connector.connector_specific_config = connector_data.connector_specific_config
|
||||
connector.refresh_freq = connector_data.refresh_freq
|
||||
connector.prune_freq = (
|
||||
connector_data.prune_freq
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ
|
||||
)
|
||||
connector.disabled = connector_data.disabled
|
||||
|
||||
db_session.commit()
|
||||
@@ -259,6 +268,7 @@ def create_initial_default_connector(db_session: Session) -> None:
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.commit()
|
||||
|
||||
@@ -151,7 +151,8 @@ def add_credential_to_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
user: User,
|
||||
is_public: bool,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
@@ -185,6 +186,7 @@ def add_credential_to_connector(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
is_public=is_public,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.commit()
|
||||
@@ -199,7 +201,7 @@ def add_credential_to_connector(
|
||||
def remove_credential_from_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -12,6 +12,7 @@ from danswer.connectors.gmail.constants import (
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
@@ -142,6 +143,18 @@ def delete_credential(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
|
||||
"Please delete all associated connectors first."
|
||||
)
|
||||
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -327,6 +327,7 @@ def prepare_to_modify_documents(
|
||||
Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
|
||||
NOTE: this function will commit any existing transaction.
|
||||
"""
|
||||
|
||||
db_session.commit() # ensure that we're not in a transaction
|
||||
|
||||
lock_acquired = False
|
||||
|
||||
@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -31,6 +36,7 @@ def create_embedding_model(
|
||||
query_prefix=model_details.query_prefix,
|
||||
passage_prefix=model_details.passage_prefix,
|
||||
status=status,
|
||||
cloud_provider_id=model_details.cloud_provider_id,
|
||||
# Every single embedding model except the initial one from migrations has this name
|
||||
# The initial one from migration is called "danswer_chunk"
|
||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
||||
@@ -42,6 +48,42 @@ def create_embedding_model(
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_model_id_from_name(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> int | None:
|
||||
query = select(CloudEmbeddingProvider).where(
|
||||
CloudEmbeddingProvider.name == embedding_provider_name
|
||||
)
|
||||
provider = db_session.execute(query).scalars().first()
|
||||
return provider.id if provider else None
|
||||
|
||||
|
||||
def get_current_db_embedding_provider(
|
||||
db_session: Session,
|
||||
) -> ServerCloudEmbeddingProvider | None:
|
||||
current_embedding_model = EmbeddingModelDetail.from_model(
|
||||
get_current_db_embedding_model(db_session=db_session)
|
||||
)
|
||||
|
||||
if (
|
||||
current_embedding_model is None
|
||||
or current_embedding_model.cloud_provider_id is None
|
||||
):
|
||||
return None
|
||||
|
||||
embedding_provider = fetch_embedding_provider(
|
||||
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
|
||||
)
|
||||
if embedding_provider is None:
|
||||
raise RuntimeError("No embedding provider exists for this model.")
|
||||
|
||||
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
|
||||
cloud_provider_model=embedding_provider
|
||||
)
|
||||
|
||||
return current_embedding_provider
|
||||
|
||||
|
||||
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
||||
query = (
|
||||
select(EmbeddingModel)
|
||||
|
||||
@@ -2,11 +2,34 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
existing_provider = (
|
||||
db_session.query(CloudEmbeddingProviderModel)
|
||||
.filter_by(name=provider.name)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
for key, value in provider.dict().items():
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.dict())
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_provider)
|
||||
return CloudEmbeddingProvider.from_request(existing_provider)
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||
) -> FullLLMProvider:
|
||||
@@ -26,7 +49,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
@@ -46,10 +68,26 @@ def upsert_llm_provider(
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
db_session: Session,
|
||||
) -> list[CloudEmbeddingProviderModel]:
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_id: int
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
return db_session.scalar(
|
||||
select(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.id == provider_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.name == embedding_provider_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
|
||||
@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
@@ -246,6 +247,39 @@ class Persona__Tool(Base):
|
||||
tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True)
|
||||
|
||||
|
||||
class StandardAnswer__StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer__standard_answer_category"
|
||||
|
||||
standard_answer_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("standard_answer.id"), primary_key=True
|
||||
)
|
||||
standard_answer_category_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("standard_answer_category.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class SlackBotConfig__StandardAnswerCategory(Base):
|
||||
__tablename__ = "slack_bot_config__standard_answer_category"
|
||||
|
||||
slack_bot_config_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("slack_bot_config.id"), primary_key=True
|
||||
)
|
||||
standard_answer_category_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("standard_answer_category.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage__StandardAnswer(Base):
|
||||
__tablename__ = "chat_message__standard_answer"
|
||||
|
||||
chat_message_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id"), primary_key=True
|
||||
)
|
||||
standard_answer_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("standard_answer.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Documents/Indexing Tables
|
||||
"""
|
||||
@@ -383,6 +417,7 @@ class Connector(Base):
|
||||
postgresql.JSONB()
|
||||
)
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -435,7 +470,7 @@ class Credential(Base):
|
||||
|
||||
class EmbeddingModel(Base):
|
||||
__tablename__ = "embedding_model"
|
||||
# ID is used also to indicate the order that the models are configured by the admin
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
model_name: Mapped[str] = mapped_column(String)
|
||||
model_dim: Mapped[int] = mapped_column(Integer)
|
||||
@@ -447,6 +482,16 @@ class EmbeddingModel(Base):
|
||||
)
|
||||
index_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# New field for cloud provider relationship
|
||||
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.id")
|
||||
)
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
"CloudEmbeddingProvider",
|
||||
back_populates="embedding_models",
|
||||
foreign_keys=[cloud_provider_id],
|
||||
)
|
||||
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="embedding_model"
|
||||
)
|
||||
@@ -466,6 +511,18 @@ class EmbeddingModel(Base):
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider else None
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str | None:
|
||||
return self.cloud_provider.name if self.cloud_provider else None
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
@@ -485,6 +542,7 @@ class IndexAttempt(Base):
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
# This is only for attempts that are explicitly marked as from the start via
|
||||
# the run once API
|
||||
@@ -611,6 +669,10 @@ class SearchDoc(Base):
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
|
||||
|
||||
is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
@@ -660,6 +722,12 @@ class ChatSession(Base):
|
||||
ForeignKey("chat_folder.id"), nullable=True
|
||||
)
|
||||
|
||||
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
|
||||
|
||||
slack_thread_id: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True, default=None
|
||||
)
|
||||
|
||||
# the latest "overrides" specified by the user. These take precedence over
|
||||
# the attached persona. However, overrides specified directly in the
|
||||
# `send-message` call will take precedence over these.
|
||||
@@ -687,7 +755,7 @@ class ChatSession(Base):
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
)
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="delete"
|
||||
"ChatMessage", back_populates="chat_session"
|
||||
)
|
||||
persona: Mapped["Persona"] = relationship("Persona")
|
||||
|
||||
@@ -705,6 +773,11 @@ class ChatMessage(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
|
||||
|
||||
alternate_assistant_id = mapped_column(
|
||||
Integer, ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
|
||||
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
@@ -733,11 +806,15 @@ class ChatMessage(Base):
|
||||
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
|
||||
|
||||
chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship(
|
||||
"ChatMessageFeedback", back_populates="chat_message"
|
||||
"ChatMessageFeedback",
|
||||
back_populates="chat_message",
|
||||
)
|
||||
|
||||
document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="chat_message"
|
||||
"DocumentRetrievalFeedback",
|
||||
back_populates="chat_message",
|
||||
)
|
||||
search_docs: Mapped[list["SearchDoc"]] = relationship(
|
||||
"SearchDoc",
|
||||
@@ -748,6 +825,11 @@ class ChatMessage(Base):
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
@@ -784,7 +866,9 @@ class DocumentRetrievalFeedback(Base):
|
||||
__tablename__ = "document_retrieval_feedback"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
chat_message_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
|
||||
# How high up this document is in the results, 1 for first
|
||||
document_rank: Mapped[int] = mapped_column(Integer)
|
||||
@@ -794,7 +878,9 @@ class DocumentRetrievalFeedback(Base):
|
||||
)
|
||||
|
||||
chat_message: Mapped[ChatMessage] = relationship(
|
||||
"ChatMessage", back_populates="document_feedbacks"
|
||||
"ChatMessage",
|
||||
back_populates="document_feedbacks",
|
||||
foreign_keys=[chat_message_id],
|
||||
)
|
||||
document: Mapped[Document] = relationship(
|
||||
"Document", back_populates="retrieval_feedbacks"
|
||||
@@ -805,22 +891,21 @@ class ChatMessageFeedback(Base):
|
||||
__tablename__ = "chat_feedback"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
chat_message_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
chat_message: Mapped[ChatMessage] = relationship(
|
||||
"ChatMessage", back_populates="chat_message_feedbacks"
|
||||
"ChatMessage",
|
||||
back_populates="chat_message_feedbacks",
|
||||
foreign_keys=[chat_message_id],
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Structures, Organizational, Configurations Tables
|
||||
"""
|
||||
|
||||
|
||||
class LLMProvider(Base):
|
||||
__tablename__ = "llm_provider"
|
||||
|
||||
@@ -849,6 +934,29 @@ class LLMProvider(Base):
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
default_model_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("embedding_model.id"), nullable=True
|
||||
)
|
||||
|
||||
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
|
||||
"EmbeddingModel",
|
||||
back_populates="cloud_provider",
|
||||
foreign_keys="EmbeddingModel.cloud_provider_id",
|
||||
)
|
||||
default_model: Mapped["EmbeddingModel"] = relationship(
|
||||
"EmbeddingModel", foreign_keys=[default_model_id]
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingProvider(name='{self.name}')>"
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
__tablename__ = "document_set"
|
||||
|
||||
@@ -928,6 +1036,7 @@ class Tool(Base):
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
# tools defined via the UI will have this as None
|
||||
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_name: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
@@ -1057,13 +1166,60 @@ class ChannelConfig(TypedDict):
|
||||
channel_names: list[str]
|
||||
respond_tag_only: NotRequired[bool] # defaults to False
|
||||
respond_to_bots: NotRequired[bool] # defaults to False
|
||||
respond_team_member_list: NotRequired[list[str]]
|
||||
respond_member_group_list: NotRequired[list[str]]
|
||||
answer_filters: NotRequired[list[AllowedAnswerFilters]]
|
||||
# If None then no follow up
|
||||
# If empty list, follow up with no tags
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1084,7 +1240,16 @@ class SlackBotConfig(Base):
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
)
|
||||
|
||||
|
||||
class TaskQueueState(Base):
|
||||
@@ -1364,3 +1529,30 @@ class EmailToExternalUserCache(Base):
|
||||
)
|
||||
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
"""This stores metadata about usage reports generated by admin including user who generated
|
||||
them as well las the period they cover. The actual zip file of the report is stored as a lo
|
||||
using the PGFileStore
|
||||
"""
|
||||
|
||||
__tablename__ = "usage_reports"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
report_name: Mapped[str] = mapped_column(ForeignKey("file_store.file_name"))
|
||||
|
||||
# if None, report was auto-generated
|
||||
requestor_user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), nullable=True
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
period_from: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True)
|
||||
)
|
||||
period_to: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
requestor = relationship("User")
|
||||
file = relationship("PGFileStore")
|
||||
|
||||
@@ -12,8 +12,8 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import Persona
|
||||
@@ -62,19 +62,6 @@ def create_update_persona(
|
||||
) -> PersonaSnapshot:
|
||||
"""Higher level function than upsert_persona, although either is valid to use."""
|
||||
# Permission to actually use these is checked later
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
prompts = list(
|
||||
get_prompts_by_ids(
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
@@ -85,9 +72,9 @@ def create_update_persona(
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
prompts=prompts,
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
document_sets=document_sets,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -330,13 +317,13 @@ def upsert_persona(
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
prompts: list[Prompt] | None,
|
||||
document_sets: list[DocumentSet] | None,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
prompt_ids: list[int] | None = None,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
@@ -356,6 +343,28 @@ def upsert_persona(
|
||||
if not tools and tool_ids:
|
||||
raise ValueError("Tools not found")
|
||||
|
||||
# Fetch and attach document_sets by IDs
|
||||
document_sets = None
|
||||
if document_set_ids is not None:
|
||||
document_sets = (
|
||||
db_session.query(DocumentSet)
|
||||
.filter(DocumentSet.id.in_(document_set_ids))
|
||||
.all()
|
||||
)
|
||||
if not document_sets and document_set_ids:
|
||||
raise ValueError("document_sets not found")
|
||||
|
||||
# Fetch and attach prompts by IDs
|
||||
prompts = None
|
||||
if prompt_ids is not None:
|
||||
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
|
||||
if not prompts and prompt_ids:
|
||||
raise ValueError("prompts not found")
|
||||
|
||||
# ensure all specified tools are valid
|
||||
if tools:
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
@@ -383,10 +392,10 @@ def upsert_persona(
|
||||
|
||||
if prompts is not None:
|
||||
persona.prompts.clear()
|
||||
persona.prompts = prompts
|
||||
persona.prompts = prompts or []
|
||||
|
||||
if tools is not None:
|
||||
persona.tools = tools
|
||||
persona.tools = tools or []
|
||||
|
||||
else:
|
||||
persona = Persona(
|
||||
@@ -453,6 +462,14 @@ def update_persona_visibility(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
for tool in tools:
|
||||
if tool.name == "InternetSearchTool" and not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Bing API key not found, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
|
||||
|
||||
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
|
||||
# if user is None, assume that no-auth is turned on
|
||||
if user is None:
|
||||
@@ -537,12 +554,22 @@ def get_persona_by_id(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_for_edit: bool = True, # NOTE: assume true for safety
|
||||
) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id)
|
||||
|
||||
or_conditions = []
|
||||
|
||||
# if user is an admin, they should have access to all Personas
|
||||
if user is not None and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(or_(Persona.user_id == user.id, Persona.user_id.is_(None)))
|
||||
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
|
||||
|
||||
# if we aren't editing, also give access to all public personas
|
||||
if not is_for_edit:
|
||||
or_conditions.append(Persona.is_public.is_(True))
|
||||
|
||||
if or_conditions:
|
||||
stmt = stmt.where(or_(*or_conditions))
|
||||
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Persona.deleted.is_(False))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
@@ -6,6 +7,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.models import PGFileStore
|
||||
from danswer.file_store.constants import MAX_IN_MEMORY_SIZE
|
||||
from danswer.file_store.constants import STANDARD_CHUNK_SIZE
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -15,6 +18,25 @@ def get_pg_conn_from_session(db_session: Session) -> connection:
|
||||
return db_session.connection().connection.connection # type: ignore
|
||||
|
||||
|
||||
def get_pgfilestore_by_file_name(
|
||||
file_name: str,
|
||||
db_session: Session,
|
||||
) -> PGFileStore:
|
||||
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
|
||||
|
||||
if not pgfilestore:
|
||||
raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
|
||||
|
||||
return pgfilestore
|
||||
|
||||
|
||||
def delete_pgfilestore_by_file_name(
|
||||
file_name: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
|
||||
|
||||
|
||||
def create_populate_lobj(
|
||||
content: IO,
|
||||
db_session: Session,
|
||||
@@ -26,18 +48,40 @@ def create_populate_lobj(
|
||||
pg_conn = get_pg_conn_from_session(db_session)
|
||||
large_object = pg_conn.lobject()
|
||||
|
||||
large_object.write(content.read())
|
||||
# write in multiple chunks to avoid loading the whole file into memory
|
||||
while True:
|
||||
chunk = content.read(STANDARD_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
large_object.write(chunk)
|
||||
|
||||
large_object.close()
|
||||
|
||||
return large_object.oid
|
||||
|
||||
|
||||
def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO:
|
||||
def read_lobj(
|
||||
lobj_oid: int,
|
||||
db_session: Session,
|
||||
mode: str | None = None,
|
||||
use_tempfile: bool = False,
|
||||
) -> IO:
|
||||
pg_conn = get_pg_conn_from_session(db_session)
|
||||
large_object = (
|
||||
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
|
||||
)
|
||||
return BytesIO(large_object.read())
|
||||
|
||||
if use_tempfile:
|
||||
temp_file = tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE)
|
||||
while True:
|
||||
chunk = large_object.read(STANDARD_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
temp_file.write(chunk)
|
||||
temp_file.seek(0)
|
||||
return temp_file
|
||||
else:
|
||||
return BytesIO(large_object.read())
|
||||
|
||||
|
||||
def delete_lobj_by_id(
|
||||
@@ -48,6 +92,23 @@ def delete_lobj_by_id(
|
||||
pg_conn.lobject(lobj_oid).unlink()
|
||||
|
||||
|
||||
def delete_lobj_by_name(
|
||||
lobj_name: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
try:
|
||||
pgfilestore = get_pgfilestore_by_file_name(lobj_name, db_session)
|
||||
except RuntimeError:
|
||||
logger.info(f"no file with name {lobj_name} found")
|
||||
return
|
||||
|
||||
pg_conn = get_pg_conn_from_session(db_session)
|
||||
pg_conn.lobject(pgfilestore.lobj_oid).unlink()
|
||||
|
||||
delete_pgfilestore_by_file_name(lobj_name, db_session)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_pgfilestore(
|
||||
file_name: str,
|
||||
display_name: str | None,
|
||||
@@ -87,22 +148,3 @@ def upsert_pgfilestore(
|
||||
db_session.commit()
|
||||
|
||||
return pgfilestore
|
||||
|
||||
|
||||
def get_pgfilestore_by_file_name(
|
||||
file_name: str,
|
||||
db_session: Session,
|
||||
) -> PGFileStore:
|
||||
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
|
||||
|
||||
if not pgfilestore:
|
||||
raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
|
||||
|
||||
return pgfilestore
|
||||
|
||||
|
||||
def delete_pgfilestore_by_file_name(
|
||||
file_name: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
@@ -15,6 +14,7 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
@@ -42,12 +42,6 @@ def create_slack_bot_persona(
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
# create/update persona associated with the slack bot
|
||||
persona_name = _build_persona_name(channel_names)
|
||||
@@ -59,10 +53,10 @@ def create_slack_bot_persona(
|
||||
description="",
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompts=[default_prompt],
|
||||
document_sets=document_sets,
|
||||
prompt_ids=[default_prompt.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -79,12 +73,25 @@ def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
db_session.add(slack_bot_config)
|
||||
db_session.commit()
|
||||
@@ -97,6 +104,8 @@ def update_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
slack_bot_config = db_session.scalar(
|
||||
@@ -106,6 +115,16 @@ def update_slack_bot_config(
|
||||
raise ValueError(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
# get the existing persona id before updating the object
|
||||
existing_persona_id = slack_bot_config.persona_id
|
||||
|
||||
@@ -115,6 +134,10 @@ def update_slack_bot_config(
|
||||
slack_bot_config.persona_id = persona_id
|
||||
slack_bot_config.channel_config = channel_config
|
||||
slack_bot_config.response_type = response_type
|
||||
slack_bot_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
slack_bot_config.enable_auto_filters = enable_auto_filters
|
||||
|
||||
# if the persona has changed, then clean up the old persona
|
||||
if persona_id != existing_persona_id and existing_persona_id:
|
||||
|
||||
239
backend/danswer/db/standard_answer.py
Normal file
239
backend/danswer/db/standard_answer.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import StandardAnswer
|
||||
from danswer.db.models import StandardAnswerCategory
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_names(
|
||||
standard_answer_category_names: list[str],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.name.in_(standard_answer_category_names)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def find_matching_standard_answers(
|
||||
id_in: list[int],
|
||||
query: str,
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
stmt = (
|
||||
select(StandardAnswer)
|
||||
.where(StandardAnswer.active.is_(True))
|
||||
.where(StandardAnswer.id.in_(id_in))
|
||||
)
|
||||
possible_standard_answers = db_session.scalars(stmt).all()
|
||||
|
||||
matching_standard_answers: list[StandardAnswer] = []
|
||||
for standard_answer in possible_standard_answers:
|
||||
# Remove punctuation and split the keyword into individual words
|
||||
keyword_words = "".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Remove punctuation and split the query into individual words
|
||||
query_words = "".join(
|
||||
char for char in query.lower() if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Check if all of the keyword words are in the query words
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(standard_answer)
|
||||
|
||||
return matching_standard_answers
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
@@ -1,5 +1,6 @@
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -91,9 +92,10 @@ def create_or_add_document_tag_list(
|
||||
new_tags.append(new_tag)
|
||||
existing_tag_values.add(tag_value)
|
||||
|
||||
logger.debug(
|
||||
f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}"
|
||||
)
|
||||
if new_tags:
|
||||
logger.debug(
|
||||
f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}"
|
||||
)
|
||||
|
||||
all_tags = existing_tags + new_tags
|
||||
|
||||
@@ -106,18 +108,28 @@ def create_or_add_document_tag_list(
|
||||
|
||||
|
||||
def get_tags_by_value_prefix_for_source_types(
|
||||
tag_key_prefix: str | None,
|
||||
tag_value_prefix: str | None,
|
||||
sources: list[DocumentSource] | None,
|
||||
limit: int | None,
|
||||
db_session: Session,
|
||||
) -> list[Tag]:
|
||||
query = select(Tag)
|
||||
|
||||
if tag_value_prefix:
|
||||
query = query.where(Tag.tag_value.startswith(tag_value_prefix))
|
||||
if tag_key_prefix or tag_value_prefix:
|
||||
conditions = []
|
||||
if tag_key_prefix:
|
||||
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
|
||||
if tag_value_prefix:
|
||||
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
|
||||
query = query.where(or_(*conditions))
|
||||
|
||||
if sources:
|
||||
query = query.where(Tag.source.in_(sources))
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = db_session.execute(query)
|
||||
|
||||
tags = result.scalars().all()
|
||||
|
||||
@@ -26,6 +26,23 @@ def get_latest_task(
|
||||
return latest_task
|
||||
|
||||
|
||||
def get_latest_task_by_type(
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> TaskQueueState | None:
|
||||
stmt = (
|
||||
select(TaskQueueState)
|
||||
.where(TaskQueueState.task_name.like(f"%{task_name}%"))
|
||||
.order_by(desc(TaskQueueState.id))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
latest_task = result.scalars().first()
|
||||
|
||||
return latest_task
|
||||
|
||||
|
||||
def register_task(
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
@@ -66,7 +83,7 @@ def mark_task_finished(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_live_task_not_timed_out(
|
||||
def check_task_is_live_and_not_timed_out(
|
||||
task: TaskQueueState,
|
||||
db_session: Session,
|
||||
timeout: int = JOB_TIMEOUT,
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import Column
|
||||
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
def list_users(db_session: Session) -> Sequence[User]:
|
||||
def list_users(db_session: Session, q: str = "") -> Sequence[User]:
|
||||
"""List all users. No pagination as of now, as the # of users
|
||||
is assumed to be relatively small (<< 1 million)"""
|
||||
return db_session.scalars(select(User)).unique().all()
|
||||
query = db_session.query(User)
|
||||
if q:
|
||||
query = query.filter(Column("email").ilike("%{}%".format(q)))
|
||||
return query.all()
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -186,7 +186,7 @@ class IdRetrievalCapable(abc.ABC):
|
||||
min_chunk_ind: int | None,
|
||||
max_chunk_ind: int | None,
|
||||
user_access_control_list: list[str] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Fetch chunk(s) based on document id
|
||||
|
||||
@@ -222,7 +222,7 @@ class KeywordCapable(abc.ABC):
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
|
||||
information required for query time purposes. For example, some details of the document
|
||||
@@ -262,7 +262,7 @@ class VectorCapable(abc.ABC):
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run vector/semantic search and return a list of inference chunks.
|
||||
|
||||
@@ -298,7 +298,7 @@ class HybridCapable(abc.ABC):
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
|
||||
@@ -348,7 +348,7 @@ class AdminCapable(abc.ABC):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run the special search for the admin document explorer page
|
||||
|
||||
|
||||
@@ -91,6 +91,9 @@ schema DANSWER_CHUNK_NAME {
|
||||
field metadata type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field metadata_suffix type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field doc_updated_at type int {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
@@ -150,43 +153,41 @@ schema DANSWER_CHUNK_NAME {
|
||||
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
|
||||
}
|
||||
|
||||
# This must be separate function for normalize_linear to work
|
||||
function vector_score() {
|
||||
function title_vector_score() {
|
||||
expression {
|
||||
# If no title, the full vector score comes from the content embedding
|
||||
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
|
||||
((1 - query(title_content_ratio)) * closeness(field, embeddings))
|
||||
}
|
||||
}
|
||||
|
||||
# This must be separate function for normalize_linear to work
|
||||
function keyword_score() {
|
||||
expression {
|
||||
(query(title_content_ratio) * bm25(title)) +
|
||||
((1 - query(title_content_ratio)) * bm25(content))
|
||||
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
|
||||
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
|
||||
}
|
||||
}
|
||||
|
||||
first-phase {
|
||||
expression: vector_score
|
||||
expression: closeness(field, embeddings)
|
||||
}
|
||||
|
||||
# Weighted average between Vector Search and BM-25
|
||||
# Each is a weighted average between the Title and Content fields
|
||||
# Finally each doc is boosted by it's user feedback based boost and recency
|
||||
# If any embedding or index field is missing, it just receives a score of 0
|
||||
# Assumptions:
|
||||
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
|
||||
# therefore not normalizing before combining.
|
||||
# - For documents without title, it gets a score of 0 for that and this is ok as documents
|
||||
# without any title match should be penalized.
|
||||
global-phase {
|
||||
expression {
|
||||
(
|
||||
# Weighted Vector Similarity Score
|
||||
(query(alpha) * normalize_linear(vector_score)) +
|
||||
(
|
||||
query(alpha) * (
|
||||
(query(title_content_ratio) * normalize_linear(title_vector_score))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
|
||||
)
|
||||
)
|
||||
|
||||
+
|
||||
|
||||
# Weighted Keyword Similarity Score
|
||||
((1 - query(alpha)) * normalize_linear(keyword_score))
|
||||
(
|
||||
(1 - query(alpha)) * (
|
||||
(query(title_content_ratio) * normalize_linear(bm25(title)))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
|
||||
)
|
||||
)
|
||||
)
|
||||
# Boost based on user feedback
|
||||
* document_boost
|
||||
@@ -201,8 +202,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
bm25(content)
|
||||
closeness(field, title_embedding)
|
||||
closeness(field, embeddings)
|
||||
keyword_score
|
||||
vector_score
|
||||
document_boost
|
||||
recency_bias
|
||||
closest(embeddings)
|
||||
|
||||
@@ -41,6 +41,7 @@ from danswer.configs.constants import HIDDEN
|
||||
from danswer.configs.constants import INDEX_SEPARATOR
|
||||
from danswer.configs.constants import METADATA
|
||||
from danswer.configs.constants import METADATA_LIST
|
||||
from danswer.configs.constants import METADATA_SUFFIX
|
||||
from danswer.configs.constants import PRIMARY_OWNERS
|
||||
from danswer.configs.constants import RECENCY_BIAS
|
||||
from danswer.configs.constants import SECONDARY_OWNERS
|
||||
@@ -51,7 +52,6 @@ from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.constants import TITLE
|
||||
from danswer.configs.constants import TITLE_EMBEDDING
|
||||
from danswer.configs.constants import TITLE_SEPARATOR
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
@@ -64,7 +64,7 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
|
||||
from danswer.document_index.vespa.utils import replace_invalid_doc_id_characters
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.search.retrieval.search_runner import query_processing
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.utils.batching import batch_generator
|
||||
@@ -119,6 +119,7 @@ def _does_document_exist(
|
||||
chunk. This checks for whether the chunk exists already in the index"""
|
||||
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
doc_fetch_response = http_client.get(doc_url)
|
||||
|
||||
if doc_fetch_response.status_code == 404:
|
||||
return False
|
||||
|
||||
@@ -346,8 +347,10 @@ def _index_vespa_chunk(
|
||||
TITLE: remove_invalid_unicode_chars(title) if title else None,
|
||||
SKIP_TITLE_EMBEDDING: not title,
|
||||
CONTENT: remove_invalid_unicode_chars(chunk.content),
|
||||
# This duplication of `content` is needed for keyword highlighting :(
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
# which contains the title prefix and metadata suffix
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
|
||||
SOURCE_TYPE: str(document.source.value),
|
||||
SOURCE_LINKS: json.dumps(chunk.source_links),
|
||||
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
|
||||
@@ -355,6 +358,7 @@ def _index_vespa_chunk(
|
||||
METADATA: json.dumps(document.metadata),
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
|
||||
METADATA_SUFFIX: chunk.metadata_suffix,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
BOOST: chunk.boost,
|
||||
@@ -559,7 +563,9 @@ def _process_dynamic_summary(
|
||||
return processed_summary
|
||||
|
||||
|
||||
def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
def _vespa_hit_to_inference_chunk(
|
||||
hit: dict[str, Any], null_score: bool = False
|
||||
) -> InferenceChunkUncleaned:
|
||||
fields = cast(dict[str, Any], hit["fields"])
|
||||
|
||||
# parse fields that are stored as strings, but are really json / datetime
|
||||
@@ -582,19 +588,6 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
|
||||
)
|
||||
|
||||
# Remove the title from the first chunk as every chunk already included
|
||||
# its semantic identifier for LLM
|
||||
content = fields[CONTENT]
|
||||
if fields[CHUNK_ID] == 0:
|
||||
parts = content.split(TITLE_SEPARATOR, maxsplit=1)
|
||||
content = parts[1] if len(parts) > 1 and "\n" not in parts[0] else content
|
||||
|
||||
# User ran into this, not sure why this could happen, error checking here
|
||||
blurb = fields.get(BLURB)
|
||||
if not blurb:
|
||||
logger.error(f"Chunk with id {fields.get(semantic_identifier)} ")
|
||||
blurb = ""
|
||||
|
||||
source_links = fields.get(SOURCE_LINKS, {})
|
||||
source_links_dict_unprocessed = (
|
||||
json.loads(source_links) if isinstance(source_links, str) else source_links
|
||||
@@ -604,29 +597,33 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
|
||||
}
|
||||
|
||||
return InferenceChunk(
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=fields[CHUNK_ID],
|
||||
blurb=blurb,
|
||||
content=content,
|
||||
blurb=fields.get(BLURB, ""), # Unused
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
|
||||
source_links=source_links_dict,
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
source_type=fields[SOURCE_TYPE],
|
||||
title=fields.get(TITLE),
|
||||
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
|
||||
boost=fields.get(BOOST, 1),
|
||||
recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
|
||||
score=hit.get("relevance", 0),
|
||||
score=None if null_score else hit.get("relevance", 0),
|
||||
hidden=fields.get(HIDDEN, False),
|
||||
primary_owners=fields.get(PRIMARY_OWNERS),
|
||||
secondary_owners=fields.get(SECONDARY_OWNERS),
|
||||
metadata=metadata,
|
||||
metadata_suffix=fields.get(METADATA_SUFFIX),
|
||||
match_highlights=match_highlights,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[InferenceChunk]:
|
||||
def _query_vespa(
|
||||
query_params: Mapping[str, str | int | float]
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
if "query" in query_params and not cast(str, query_params["query"]).strip():
|
||||
raise ValueError("No/empty query received")
|
||||
|
||||
@@ -681,16 +678,6 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _inference_chunk_by_vespa_id(vespa_id: str, index_name: str) -> InferenceChunk:
|
||||
res = requests.get(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_id}"
|
||||
)
|
||||
res.raise_for_status()
|
||||
|
||||
return _vespa_hit_to_inference_chunk(res.json())
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -735,6 +722,7 @@ class VespaIndex(DocumentIndex):
|
||||
f"{SOURCE_TYPE}, "
|
||||
f"{SOURCE_LINKS}, "
|
||||
f"{SEMANTIC_IDENTIFIER}, "
|
||||
f"{TITLE}, "
|
||||
f"{SECTION_CONTINUATION}, "
|
||||
f"{BOOST}, "
|
||||
f"{HIDDEN}, "
|
||||
@@ -742,6 +730,7 @@ class VespaIndex(DocumentIndex):
|
||||
f"{PRIMARY_OWNERS}, "
|
||||
f"{SECONDARY_OWNERS}, "
|
||||
f"{METADATA}, "
|
||||
f"{METADATA_SUFFIX}, "
|
||||
f"{CONTENT_SUMMARY} "
|
||||
f"from {{index_name}} where "
|
||||
)
|
||||
@@ -977,7 +966,7 @@ class VespaIndex(DocumentIndex):
|
||||
min_chunk_ind: int | None,
|
||||
max_chunk_ind: int | None,
|
||||
user_access_control_list: list[str] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
document_id = replace_invalid_doc_id_characters(document_id)
|
||||
|
||||
vespa_chunks = _get_vespa_chunks_by_document_id(
|
||||
@@ -992,7 +981,8 @@ class VespaIndex(DocumentIndex):
|
||||
return []
|
||||
|
||||
inference_chunks = [
|
||||
_vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
|
||||
_vespa_hit_to_inference_chunk(chunk, null_score=True)
|
||||
for chunk in vespa_chunks
|
||||
]
|
||||
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
|
||||
return inference_chunks
|
||||
@@ -1005,7 +995,7 @@ class VespaIndex(DocumentIndex):
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
@@ -1042,7 +1032,7 @@ class VespaIndex(DocumentIndex):
|
||||
offset: int = 0,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
@@ -1086,7 +1076,7 @@ class VespaIndex(DocumentIndex):
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
target_hits = max(10 * num_to_retrieve, 1000)
|
||||
@@ -1130,7 +1120,7 @@ class VespaIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters, include_hidden=True)
|
||||
yql = (
|
||||
VespaIndex.yql_base.format(index_name=self.index_name)
|
||||
|
||||
8
backend/danswer/file_processing/enums.py
Normal file
8
backend/danswer/file_processing/enums.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class HtmlBasedConnectorTransformLinksStrategy(str, Enum):
|
||||
# remove links entirely
|
||||
STRIP = "strip"
|
||||
# turn HTML links into markdown links
|
||||
MARKDOWN = "markdown"
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from email.parser import Parser as EmailParser
|
||||
from pathlib import Path
|
||||
@@ -16,6 +17,7 @@ import pptx # type: ignore
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
from danswer.configs.constants import DANSWER_METADATA_FILENAME
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -64,6 +66,16 @@ def check_file_ext_is_valid(ext: str) -> bool:
|
||||
return ext in VALID_FILE_EXTENSIONS
|
||||
|
||||
|
||||
def is_text_file(file: IO[bytes]) -> bool:
|
||||
"""
|
||||
checks if the first 1024 bytes only contain printable or whitespace characters
|
||||
if it does, then we say its a plaintext file
|
||||
"""
|
||||
raw_data = file.read(1024)
|
||||
text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
|
||||
return all(c in text_chars for c in raw_data)
|
||||
|
||||
|
||||
def detect_encoding(file: IO[bytes]) -> str:
|
||||
raw_data = file.read(50000)
|
||||
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
|
||||
@@ -88,7 +100,7 @@ def load_files_from_zip(
|
||||
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
|
||||
zip_metadata = {}
|
||||
try:
|
||||
metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
|
||||
metadata_file_info = zip_file.getinfo(DANSWER_METADATA_FILENAME)
|
||||
with zip_file.open(metadata_file_info, "r") as metadata_file:
|
||||
try:
|
||||
zip_metadata = json.load(metadata_file)
|
||||
@@ -96,18 +108,19 @@ def load_files_from_zip(
|
||||
# convert list of dicts to dict of dicts
|
||||
zip_metadata = {d["filename"]: d for d in zip_metadata}
|
||||
except json.JSONDecodeError:
|
||||
logger.warn("Unable to load .danswer_metadata.json")
|
||||
logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
|
||||
except KeyError:
|
||||
logger.info("No .danswer_metadata.json file")
|
||||
logger.info(f"No {DANSWER_METADATA_FILENAME} file")
|
||||
|
||||
for file_info in zip_file.infolist():
|
||||
with zip_file.open(file_info.filename, "r") as file:
|
||||
if ignore_dirs and file_info.is_dir():
|
||||
continue
|
||||
|
||||
if ignore_macos_resource_fork_files and is_macos_resource_fork_file(
|
||||
file_info.filename
|
||||
):
|
||||
if (
|
||||
ignore_macos_resource_fork_files
|
||||
and is_macos_resource_fork_file(file_info.filename)
|
||||
) or file_info.filename == DANSWER_METADATA_FILENAME:
|
||||
continue
|
||||
yield file_info, file, zip_metadata.get(file_info.filename, {})
|
||||
|
||||
@@ -259,37 +272,32 @@ def extract_file_text(
|
||||
file: IO[Any],
|
||||
break_on_unprocessable: bool = True,
|
||||
) -> str:
|
||||
if not file_name:
|
||||
return file_io_to_text(file)
|
||||
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
|
||||
".pdf": pdf_to_text,
|
||||
".docx": docx_to_text,
|
||||
".pptx": pptx_to_text,
|
||||
".xlsx": xlsx_to_text,
|
||||
".eml": eml_to_text,
|
||||
".epub": epub_to_text,
|
||||
".html": parse_html_page_basic,
|
||||
}
|
||||
|
||||
extension = get_file_ext(file_name)
|
||||
if not check_file_ext_is_valid(extension):
|
||||
def _process_file() -> str:
|
||||
if file_name:
|
||||
extension = get_file_ext(file_name)
|
||||
if check_file_ext_is_valid(extension):
|
||||
return extension_to_function.get(extension, file_io_to_text)(file)
|
||||
|
||||
# Either the file somehow has no name or the extension is not one that we are familiar with
|
||||
if is_text_file(file):
|
||||
return file_io_to_text(file)
|
||||
|
||||
raise ValueError("Unknown file extension and unknown text encoding")
|
||||
|
||||
try:
|
||||
return _process_file()
|
||||
except Exception as e:
|
||||
if break_on_unprocessable:
|
||||
raise RuntimeError(f"Unprocessable file type: {file_name}")
|
||||
else:
|
||||
logger.warning(f"Unprocessable file type: {file_name}")
|
||||
return ""
|
||||
|
||||
if extension == ".pdf":
|
||||
return pdf_to_text(file=file)
|
||||
|
||||
elif extension == ".docx":
|
||||
return docx_to_text(file)
|
||||
|
||||
elif extension == ".pptx":
|
||||
return pptx_to_text(file)
|
||||
|
||||
elif extension == ".xlsx":
|
||||
return xlsx_to_text(file)
|
||||
|
||||
elif extension == ".eml":
|
||||
return eml_to_text(file)
|
||||
|
||||
elif extension == ".epub":
|
||||
return epub_to_text(file)
|
||||
|
||||
elif extension == ".html":
|
||||
return parse_html_page_basic(file)
|
||||
|
||||
else:
|
||||
return file_io_to_text(file)
|
||||
raise RuntimeError(f"Failed to process file: {str(e)}") from e
|
||||
logger.warning(f"Failed to process file: {str(e)}")
|
||||
return ""
|
||||
|
||||
@@ -5,8 +5,10 @@ from typing import IO
|
||||
|
||||
import bs4
|
||||
|
||||
from danswer.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
|
||||
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES
|
||||
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS
|
||||
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
MINTLIFY_UNWANTED = ["sticky", "hidden"]
|
||||
|
||||
@@ -32,6 +34,19 @@ def strip_newlines(document: str) -> str:
|
||||
return re.sub(r"[\n\r]+", " ", document)
|
||||
|
||||
|
||||
def format_element_text(element_text: str, link_href: str | None) -> str:
|
||||
element_text_no_newlines = strip_newlines(element_text)
|
||||
|
||||
if (
|
||||
not link_href
|
||||
or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
|
||||
== HtmlBasedConnectorTransformLinksStrategy.STRIP
|
||||
):
|
||||
return element_text_no_newlines
|
||||
|
||||
return f"[{element_text_no_newlines}]({link_href})"
|
||||
|
||||
|
||||
def format_document_soup(
|
||||
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
|
||||
) -> str:
|
||||
@@ -49,6 +64,8 @@ def format_document_soup(
|
||||
verbatim_output = 0
|
||||
in_table = False
|
||||
last_added_newline = False
|
||||
link_href: str | None = None
|
||||
|
||||
for e in document.descendants:
|
||||
verbatim_output -= 1
|
||||
if isinstance(e, bs4.element.NavigableString):
|
||||
@@ -71,7 +88,7 @@ def format_document_soup(
|
||||
content_to_add = (
|
||||
element_text
|
||||
if verbatim_output > 0
|
||||
else strip_newlines(element_text)
|
||||
else format_element_text(element_text, link_href)
|
||||
)
|
||||
|
||||
# Don't join separate elements without any spacing
|
||||
@@ -98,7 +115,14 @@ def format_document_soup(
|
||||
elif in_table:
|
||||
# don't handle other cases while in table
|
||||
pass
|
||||
|
||||
elif e.name == "a":
|
||||
href_value = e.get("href", None)
|
||||
# mostly for typing, having multiple hrefs is not valid HTML
|
||||
link_href = (
|
||||
href_value[0] if isinstance(href_value, list) else href_value
|
||||
)
|
||||
elif e.name == "/a":
|
||||
link_href = None
|
||||
elif e.name in ["p", "div"]:
|
||||
if not list_element_start:
|
||||
text += "\n"
|
||||
|
||||
2
backend/danswer/file_store/constants.py
Normal file
2
backend/danswer/file_store/constants.py
Normal file
@@ -0,0 +1,2 @@
|
||||
MAX_IN_MEMORY_SIZE = 30 * 1024 * 1024 # 30MB
|
||||
STANDARD_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB chunks
|
||||
@@ -5,6 +5,7 @@ from typing import IO
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.models import PGFileStore
|
||||
from danswer.db.pg_file_store import create_populate_lobj
|
||||
from danswer.db.pg_file_store import delete_lobj_by_id
|
||||
from danswer.db.pg_file_store import delete_pgfilestore_by_file_name
|
||||
@@ -26,6 +27,7 @@ class FileStore(ABC):
|
||||
display_name: str | None,
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
file_metadata: dict | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save a file to the blob store
|
||||
@@ -41,12 +43,17 @@ class FileStore(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def read_file(self, file_name: str, mode: str | None) -> IO:
|
||||
def read_file(
|
||||
self, file_name: str, mode: str | None, use_tempfile: bool = False
|
||||
) -> IO:
|
||||
"""
|
||||
Read the content of a given file by the name
|
||||
|
||||
Parameters:
|
||||
- file_name: Name of file to read
|
||||
- mode: Mode to open the file (e.g. 'b' for binary)
|
||||
- use_tempfile: Whether to use a temporary file to store the contents
|
||||
in order to avoid loading the entire file into memory
|
||||
|
||||
Returns:
|
||||
Contents of the file and metadata dict
|
||||
@@ -73,6 +80,7 @@ class PostgresBackedFileStore(FileStore):
|
||||
display_name: str | None,
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
file_metadata: dict | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
# The large objects in postgres are saved as special objects can be listed with
|
||||
@@ -85,20 +93,33 @@ class PostgresBackedFileStore(FileStore):
|
||||
file_type=file_type,
|
||||
lobj_oid=obj_id,
|
||||
db_session=self.db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
self.db_session.commit()
|
||||
except Exception:
|
||||
self.db_session.rollback()
|
||||
raise
|
||||
|
||||
def read_file(self, file_name: str, mode: str | None = None) -> IO:
|
||||
def read_file(
|
||||
self, file_name: str, mode: str | None = None, use_tempfile: bool = False
|
||||
) -> IO:
|
||||
file_record = get_pgfilestore_by_file_name(
|
||||
file_name=file_name, db_session=self.db_session
|
||||
)
|
||||
return read_lobj(
|
||||
lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
|
||||
lobj_oid=file_record.lobj_oid,
|
||||
db_session=self.db_session,
|
||||
mode=mode,
|
||||
use_tempfile=use_tempfile,
|
||||
)
|
||||
|
||||
def read_file_record(self, file_name: str) -> PGFileStore:
|
||||
file_record = get_pgfilestore_by_file_name(
|
||||
file_name=file_name, db_session=self.db_session
|
||||
)
|
||||
|
||||
return file_record
|
||||
|
||||
def delete_file(self, file_name: str) -> None:
|
||||
try:
|
||||
file_record = get_pgfilestore_by_file_name(
|
||||
|
||||
@@ -3,12 +3,16 @@ from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.app_configs import CHUNK_OVERLAP
|
||||
from danswer.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.constants import SECTION_SEPARATOR
|
||||
from danswer.configs.constants import TITLE_SEPARATOR
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
@@ -19,6 +23,14 @@ if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
CHUNK_OVERLAP = 0
|
||||
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
|
||||
# overwhelm the actual contents of the chunk
|
||||
MAX_METADATA_PERCENTAGE = 0.25
|
||||
CHUNK_MIN_CONTENT = 256
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
|
||||
@@ -44,6 +56,8 @@ def chunk_large_section(
|
||||
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix: str = "",
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
@@ -60,30 +74,69 @@ def chunk_large_section(
|
||||
source_document=document,
|
||||
chunk_id=start_chunk_id + chunk_ind,
|
||||
blurb=blurb,
|
||||
content=chunk_str,
|
||||
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
|
||||
content_summary=chunk_str,
|
||||
source_links={0: section_link_text},
|
||||
section_continuation=(chunk_ind != 0),
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
for chunk_ind, chunk_str in enumerate(split_texts)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
def _get_metadata_suffix_for_document_index(
|
||||
metadata: dict[str, str | list[str]]
|
||||
) -> str:
|
||||
if not metadata:
|
||||
return ""
|
||||
metadata_str = "Metadata:\n"
|
||||
for key, value in metadata.items():
|
||||
if key in get_metadata_keys_to_ignore():
|
||||
continue
|
||||
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
metadata_str += f"\t{key} - {value_str}\n"
|
||||
return metadata_str.strip()
|
||||
|
||||
|
||||
def chunk_document(
|
||||
document: Document,
|
||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
subsection_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
) -> list[DocAwareChunk]:
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = title.replace("\n", " ") + TITLE_SEPARATOR if title else ""
|
||||
tokenizer = get_default_tokenizer()
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
|
||||
title_tokens = len(tokenizer.tokenize(title_prefix))
|
||||
|
||||
metadata_suffix = ""
|
||||
metadata_tokens = 0
|
||||
if include_metadata:
|
||||
metadata = _get_metadata_suffix_for_document_index(document.metadata)
|
||||
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
|
||||
|
||||
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
|
||||
|
||||
# If there is not enough context remaining then just index the chunk with no prefix/suffix
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
content_token_limit = chunk_tok_size
|
||||
title_prefix = ""
|
||||
metadata_suffix = ""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
for ind, section in enumerate(document.sections):
|
||||
section_text = title_prefix + section.text if ind == 0 else section.text
|
||||
for section in document.sections:
|
||||
section_text = section.text
|
||||
section_link_text = section.link or ""
|
||||
|
||||
section_tok_length = len(tokenizer.tokenize(section_text))
|
||||
@@ -92,16 +145,18 @@ def chunk_document(
|
||||
|
||||
# Large sections are considered self-contained/unique therefore they start a new chunk and are not concatenated
|
||||
# at the end by other sections
|
||||
if section_tok_length > chunk_tok_size:
|
||||
if section_tok_length > content_token_limit:
|
||||
if chunk_text:
|
||||
chunks.append(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
link_offsets = {}
|
||||
@@ -113,9 +168,11 @@ def chunk_document(
|
||||
document=document,
|
||||
start_chunk_id=len(chunks),
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=chunk_tok_size,
|
||||
chunk_size=content_token_limit,
|
||||
chunk_overlap=subsection_overlap,
|
||||
blurb_size=blurb_size,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
chunks.extend(large_section_chunks)
|
||||
continue
|
||||
@@ -125,7 +182,7 @@ def chunk_document(
|
||||
current_tok_length
|
||||
+ len(tokenizer.tokenize(SECTION_SEPARATOR))
|
||||
+ section_tok_length
|
||||
<= chunk_tok_size
|
||||
<= content_token_limit
|
||||
):
|
||||
chunk_text += (
|
||||
SECTION_SEPARATOR + section_text if chunk_text else section_text
|
||||
@@ -137,9 +194,11 @@ def chunk_document(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
link_offsets = {0: section_link_text}
|
||||
@@ -153,9 +212,11 @@ def chunk_document(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
@@ -164,6 +225,9 @@ def chunk_document(
|
||||
def split_chunk_text_into_mini_chunks(
|
||||
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
|
||||
) -> list[str]:
|
||||
"""The minichunks won't all have the title prefix or metadata suffix
|
||||
It could be a significant percentage of every minichunk so better to not include it
|
||||
"""
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user