mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-27 04:35:50 +00:00
Compare commits
183 Commits
benchmarki
...
v0.3.16
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30983657ec | ||
|
|
6b6b3daab7 | ||
|
|
20441df4a4 | ||
|
|
d7141df5fc | ||
|
|
615bb7b095 | ||
|
|
e759718c3e | ||
|
|
06d8d0e53c | ||
|
|
ae9b556876 | ||
|
|
f883611e94 | ||
|
|
13c536c033 | ||
|
|
2e6be57880 | ||
|
|
b352d83b8c | ||
|
|
aa67768c79 | ||
|
|
6004e540f3 | ||
|
|
64d2cea396 | ||
|
|
b5947a1c74 | ||
|
|
cdf260b277 | ||
|
|
73483b5e09 | ||
|
|
a6a444f365 | ||
|
|
449a403c73 | ||
|
|
4aebf824d2 | ||
|
|
26946198de | ||
|
|
e5035b8992 | ||
|
|
2e9af3086a | ||
|
|
dab3ba8a41 | ||
|
|
1e84b0daa4 | ||
|
|
f4c8abdf21 | ||
|
|
ccc5bb1e67 | ||
|
|
c3cf9134bb | ||
|
|
0370b9b38d | ||
|
|
95bf1c13ad | ||
|
|
00c1f93b12 | ||
|
|
a122510cee | ||
|
|
dca4f7a72b | ||
|
|
535dc265c5 | ||
|
|
56882367ba | ||
|
|
d9fbd7ffe2 | ||
|
|
8b7d01fb3b | ||
|
|
016a087b10 | ||
|
|
241b886976 | ||
|
|
ff014e4f5a | ||
|
|
0318507911 | ||
|
|
6650f01dc6 | ||
|
|
962e3f726a | ||
|
|
25a73b9921 | ||
|
|
dc0b3672ac | ||
|
|
c4ad03a65d | ||
|
|
c6f354fd03 | ||
|
|
2f001c23b7 | ||
|
|
4d950aa60d | ||
|
|
56406a0b53 | ||
|
|
eb31c08461 | ||
|
|
26f94c9890 | ||
|
|
a9570e01e2 | ||
|
|
402d83e167 | ||
|
|
10dcd49fc8 | ||
|
|
0fdad0e777 | ||
|
|
fab767d794 | ||
|
|
7dd70ca4c0 | ||
|
|
370760eeee | ||
|
|
24a62cb33d | ||
|
|
9e4a4ddf39 | ||
|
|
c281859509 | ||
|
|
2180a40bd3 | ||
|
|
997f9c3191 | ||
|
|
677c32ea79 | ||
|
|
edfc849652 | ||
|
|
9d296b623b | ||
|
|
5957b888a5 | ||
|
|
c7a91b1819 | ||
|
|
a099f8e296 | ||
|
|
16c8969028 | ||
|
|
65fde8f1b3 | ||
|
|
229db47e5d | ||
|
|
2e3397feb0 | ||
|
|
d5658ce477 | ||
|
|
ddf3f99da4 | ||
|
|
56785e6065 | ||
|
|
26e808d2a1 | ||
|
|
e3ac373f05 | ||
|
|
9e9a578921 | ||
|
|
f7172612e1 | ||
|
|
5aa2de7a40 | ||
|
|
e0b87d9d4e | ||
|
|
5607fdcddd | ||
|
|
651de071f7 | ||
|
|
5629ca7d96 | ||
|
|
bc403d97f2 | ||
|
|
292c78b193 | ||
|
|
ac35719038 | ||
|
|
02095e9281 | ||
|
|
8954a04602 | ||
|
|
8020db9e9a | ||
|
|
17c2f06338 | ||
|
|
9cff294a71 | ||
|
|
e983aaeca7 | ||
|
|
7ea774f35b | ||
|
|
d1846823ba | ||
|
|
fda89ac810 | ||
|
|
006fd4c438 | ||
|
|
9b7069a043 | ||
|
|
c64c25b2e1 | ||
|
|
c2727a3f19 | ||
|
|
37daf4f3e4 | ||
|
|
fcb7f6fcc0 | ||
|
|
429016d4a2 | ||
|
|
c83a450ec4 | ||
|
|
187b94a7d8 | ||
|
|
30225fd4c5 | ||
|
|
a4f053fa5b | ||
|
|
eab4fe83a0 | ||
|
|
78d1ae0379 | ||
|
|
87beb1f4d1 | ||
|
|
05c2b7d34e | ||
|
|
39d09a162a | ||
|
|
d291fea020 | ||
|
|
2665bff78e | ||
|
|
65d38ac8c3 | ||
|
|
8391d89bea | ||
|
|
ac2ed31726 | ||
|
|
47f947b045 | ||
|
|
63b051b342 | ||
|
|
a5729e2fa6 | ||
|
|
3cec854c5c | ||
|
|
26c6651a03 | ||
|
|
13001ede98 | ||
|
|
fda377a2fa | ||
|
|
bdfb894507 | ||
|
|
35c3511daa | ||
|
|
c1e19d0d93 | ||
|
|
e78aefb408 | ||
|
|
aa2e859b46 | ||
|
|
c0c8ae6c08 | ||
|
|
1225c663eb | ||
|
|
e052d607d5 | ||
|
|
8e5e11a554 | ||
|
|
57f0323f52 | ||
|
|
6e9f31d1e9 | ||
|
|
eeb844e35e | ||
|
|
d6a84ab413 | ||
|
|
68160d49dd | ||
|
|
0cc3d65839 | ||
|
|
df37387146 | ||
|
|
f72825cd46 | ||
|
|
6fb07d20cc | ||
|
|
b258ec1bed | ||
|
|
4fd55b8928 | ||
|
|
b3ea53fa46 | ||
|
|
fa0d19cc8c | ||
|
|
d5916e420c | ||
|
|
39b912befd | ||
|
|
37c5f24d91 | ||
|
|
ae72cd56f8 | ||
|
|
be5ef77896 | ||
|
|
0ed8f14015 | ||
|
|
a03e443541 | ||
|
|
4935459798 | ||
|
|
efb52873dd | ||
|
|
442f7595cc | ||
|
|
81cbcbb403 | ||
|
|
0a0e672b35 | ||
|
|
69644b266e | ||
|
|
5a4820c55f | ||
|
|
a5d69bb392 | ||
|
|
23ee45c033 | ||
|
|
31bfd015ae | ||
|
|
0125d8a0f6 | ||
|
|
4f64444f0f | ||
|
|
abf9cc3248 | ||
|
|
f5bf2e6374 | ||
|
|
24b3b1fa9e | ||
|
|
7433dddac3 | ||
|
|
fe938b6fc6 | ||
|
|
2db029672b | ||
|
|
602f9c4a0a | ||
|
|
551705ad62 | ||
|
|
d9581ce0ae | ||
|
|
e27800d501 | ||
|
|
927dffecb5 | ||
|
|
68b23b6339 | ||
|
|
174f54473e | ||
|
|
329824ab22 | ||
|
|
b0f76b97ef |
15
.github/ISSUE_TEMPLATE/sweep-template.yml
vendored
Normal file
15
.github/ISSUE_TEMPLATE/sweep-template.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
name: Sweep Issue
|
||||
title: 'Sweep: '
|
||||
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
|
||||
labels: sweep
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Details
|
||||
description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
|
||||
placeholder: |
|
||||
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
|
||||
Bugs: The bug might be in <FILE>. Here are the logs: ...
|
||||
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
|
||||
Refactors: We are migrating this function to ... version because ...
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Build and Push Backend Images on Tagging
|
||||
name: Build and Push Backend Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -32,3 +32,11 @@ jobs:
|
||||
tags: |
|
||||
danswer/danswer-backend:${{ github.ref_name }}
|
||||
danswer/danswer-backend:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
42
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
42
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Build and Push Model Server Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-model-server:${{ github.ref_name }}
|
||||
danswer/danswer-model-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Build and Push Web Images on Tagging
|
||||
name: Build and Push Web Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -32,3 +32,11 @@ jobs:
|
||||
tags: |
|
||||
danswer/danswer-web-server:${{ github.ref_name }}
|
||||
danswer/danswer-web-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.env
|
||||
.DS_store
|
||||
.venv
|
||||
.venv
|
||||
.mypy_cache
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Danswer
|
||||
Hey there! We are so excited that you're interested in Danswer.
|
||||
|
||||
@@ -86,7 +88,12 @@ Once the above is done, navigate to `danswer/web` run:
|
||||
npm i
|
||||
```
|
||||
|
||||
Install Playwright (required by the Web Connector), with the python venv active, run:
|
||||
Install Playwright (required by the Web Connector)
|
||||
|
||||
> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
|
||||
This will update the path to include playwright
|
||||
|
||||
Then install Playwright by running:
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
@@ -113,7 +120,7 @@ npm run dev
|
||||
|
||||
Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally.
|
||||
|
||||
Nagivate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
|
||||
Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
|
||||
```bash
|
||||
zip -r ../vespa-app.zip .
|
||||
```
|
||||
|
||||
32
README.md
32
README.md
@@ -1,3 +1,5 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
</h2>
|
||||
@@ -9,7 +11,7 @@
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1u5ycen3o-6SJbWfivLWP5LPyp_jftuw" target="_blank">
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1u3h3ke3b-VGh1idW19R8oiNRiKBYv2w" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -27,7 +29,11 @@
|
||||
Danswer provides a fully-featured web UI:
|
||||
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/619607a1-4ad2-41a0-9728-351752acc26e
|
||||
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
|
||||
|
||||
|
||||
Or, if you prefer, you can plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
@@ -45,37 +51,43 @@ Danswer can easily be tested locally or deployed on a virtual machine with a sin
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Features
|
||||
* Direct QA powered by Generative AI models with answers backed by quotes and source links.
|
||||
* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
|
||||
* An AI Helper backed by a custom Deep Learning model to interpret user intent.
|
||||
* Direct QA + Chat powered by Generative AI models with answers backed by quotes and source links.
|
||||
* Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
|
||||
* Automatic time/source filter extraction from natural language + custom model to identify user intent.
|
||||
* User authentication and document level access management.
|
||||
* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
|
||||
* Management Dashboard to manage connectors and set up features such as live update fetching.
|
||||
* Support for LLMs of your choice (GPT-4, Mixstral, Llama2, etc.)
|
||||
* Management Dashboards to manage connectors and set up features such as live update fetching.
|
||||
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.
|
||||
|
||||
## 🔌 Connectors
|
||||
|
||||
Danswer currently syncs documents (every 10 minutes) from:
|
||||
Efficiently pulls the latest changes from:
|
||||
* Slack
|
||||
* GitHub
|
||||
* Google Drive
|
||||
* Confluence
|
||||
* Jira
|
||||
* Zendesk
|
||||
* Notion
|
||||
* Gong
|
||||
* Slab
|
||||
* Linear
|
||||
* Productboard
|
||||
* Guru
|
||||
* Zulip
|
||||
* Bookstack
|
||||
* Document360
|
||||
* Request Tracker
|
||||
* Hubspot
|
||||
* Local Files
|
||||
* Websites
|
||||
* With more to come...
|
||||
|
||||
## 🚧 Roadmap
|
||||
* Chat/Conversation support.
|
||||
* Organizational understanding.
|
||||
* Ability to locate and suggest experts.
|
||||
* Ability to locate and suggest experts from your team.
|
||||
* Code Search
|
||||
* Structured Query Languages (SQL, Excel formulas, etc.)
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
17
backend/.dockerignore
Normal file
17
backend/.dockerignore
Normal file
@@ -0,0 +1,17 @@
|
||||
**/__pycache__
|
||||
venv/
|
||||
env/
|
||||
*.egg-info
|
||||
.cache
|
||||
.git/
|
||||
.svn/
|
||||
.vscode/
|
||||
.idea/
|
||||
*.log
|
||||
log/
|
||||
.env
|
||||
secrets.yaml
|
||||
build/
|
||||
dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
__pycache__/
|
||||
.mypy_cache
|
||||
.idea/
|
||||
site_crawls/
|
||||
.ipynb_checkpoints/
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
FROM python:3.11.4-slim-bookworm
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
# Install system dependencies
|
||||
# cmake needed for psycopg (postgres)
|
||||
# libpq-dev needed for psycopg (postgres)
|
||||
# curl included just for users' convenience
|
||||
# zip for Vespa step futher down
|
||||
# ca-certificates for HTTPS
|
||||
RUN apt-get update && \
|
||||
apt-get install -y git cmake pkg-config libprotobuf-c-dev protobuf-compiler \
|
||||
libprotobuf-dev libgoogle-perftools-dev libpq-dev build-essential cron curl \
|
||||
supervisor zip ca-certificates gnupg && \
|
||||
apt-get install -y cmake curl zip ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
@@ -13,27 +21,15 @@ RUN apt-get update && \
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
playwright install chromium && \
|
||||
playwright install-deps chromium
|
||||
|
||||
# install nodejs and replace nodejs packaged with playwright (18.17.0) with the one installed below
|
||||
# based on the instructions found here:
|
||||
# https://nodejs.org/en/download/package-manager#debian-and-ubuntu-based-linux-distributions
|
||||
# this is temporarily needed until playwright updates their packaged node version to
|
||||
# 20.5.1+
|
||||
RUN mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y nodejs && \
|
||||
cp /usr/bin/node /usr/local/lib/python3.11/site-packages/playwright/driver/node && \
|
||||
apt-get remove -y nodejs
|
||||
playwright install chromium && playwright install-deps chromium && \
|
||||
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
|
||||
|
||||
# Cleanup for CVEs and size reduction
|
||||
# Remove tornado test key to placate vulnerability scanners
|
||||
# More details can be found here:
|
||||
# https://github.com/tornadoweb/tornado/issues/3107
|
||||
RUN apt-get remove -y linux-libc-dev && \
|
||||
# xserver-common and xvfb included by playwright installation but not needed after
|
||||
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
|
||||
# perl-base could only be removed with --allow-remove-essential
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
@@ -41,18 +37,16 @@ RUN apt-get remove -y linux-libc-dev && \
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_models /app/shared_models
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
||||
# Create Vespa app zip
|
||||
WORKDIR /app/danswer/document_index/vespa/app_config
|
||||
RUN zip -r /app/danswer/vespa-app.zip .
|
||||
WORKDIR /app
|
||||
|
||||
# TODO: remove this once all users have migrated
|
||||
COPY ./scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
# Default command which does nothing
|
||||
|
||||
39
backend/Dockerfile.model_server
Normal file
39
backend/Dockerfile.model_server
Normal file
@@ -0,0 +1,39 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.3-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Needed for model configs and defaults
|
||||
COPY ./danswer/configs /app/danswer/configs
|
||||
COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
|
||||
|
||||
# Utils used by model server
|
||||
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
|
||||
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
|
||||
COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./danswer/__init__.py /app/danswer/__init__.py
|
||||
|
||||
# Shared implementations for running NLP models locally
|
||||
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
|
||||
|
||||
# Request/Response models
|
||||
COPY ./shared_models /app/shared_models
|
||||
|
||||
# Model Server main code
|
||||
COPY ./model_server /app/model_server
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
|
||||
@@ -1,4 +1,8 @@
|
||||
Generic single-database configuration with an async dbapi.
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/alembic/README.md"} -->
|
||||
|
||||
# Alembic DB Migrations
|
||||
These files are for creating/updating the tables in the Relational DB (Postgres).
|
||||
Danswer migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
run from danswer/backend:
|
||||
@@ -7,7 +11,6 @@ run from danswer/backend:
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
## Running migrations
|
||||
|
||||
To run all un-applied migrations:
|
||||
`alembic upgrade head`
|
||||
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Introduce Danswer APIs
|
||||
|
||||
Revision ID: 15326fcec57e
|
||||
Revises: 77d07dffae64
|
||||
Create Date: 2023-11-11 20:51:24.228999
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "15326fcec57e"
|
||||
down_revision = "77d07dffae64"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("credential", "is_admin", new_column_name="admin_public")
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("from_ingestion_api", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.alter_column(
|
||||
"connector",
|
||||
"source",
|
||||
type_=sa.String(length=50),
|
||||
existing_type=sa.Enum(DocumentSource, native_enum=False),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "from_ingestion_api")
|
||||
op.alter_column("credential", "admin_public", new_column_name="is_admin")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Add additional retrieval controls to Persona
|
||||
|
||||
Revision ID: 50b683a8295c
|
||||
Revises: 7da0ae5ad583
|
||||
Create Date: 2023-11-27 17:23:29.668422
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "50b683a8295c"
|
||||
down_revision = "7da0ae5ad583"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "apply_llm_relevance_filter")
|
||||
op.drop_column("persona", "num_chunks")
|
||||
@@ -0,0 +1,32 @@
|
||||
"""CC-Pair Name not Unique
|
||||
|
||||
Revision ID: 76b60d407dfb
|
||||
Revises: b156fa702355
|
||||
Create Date: 2023-12-22 21:42:10.018804
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "76b60d407dfb"
|
||||
down_revision = "b156fa702355"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("DELETE FROM connector_credential_pair WHERE name IS NULL")
|
||||
op.drop_constraint(
|
||||
"connector_credential_pair__name__key",
|
||||
"connector_credential_pair",
|
||||
type_="unique",
|
||||
)
|
||||
op.alter_column(
|
||||
"connector_credential_pair", "name", existing_type=sa.String(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# This wasn't really required by the code either, no good reason to make it unique again
|
||||
pass
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Add description to persona
|
||||
|
||||
Revision ID: 7da0ae5ad583
|
||||
Revises: e86866a9c78a
|
||||
Create Date: 2023-11-27 00:16:19.959414
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7da0ae5ad583"
|
||||
down_revision = "e86866a9c78a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("description", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "description")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add chat session to query_event
|
||||
|
||||
Revision ID: 80696cf850ae
|
||||
Revises: 15326fcec57e
|
||||
Create Date: 2023-11-26 02:38:35.008070
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "80696cf850ae"
|
||||
down_revision = "15326fcec57e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"query_event",
|
||||
sa.Column("chat_session_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_query_event_chat_session_id",
|
||||
"query_event",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_query_event_chat_session_id", "query_event", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("query_event", "chat_session_id")
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Add is_visible to Persona
|
||||
|
||||
Revision ID: 891cd83c87a8
|
||||
Revises: 76b60d407dfb
|
||||
Create Date: 2023-12-21 11:55:54.132279
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "891cd83c87a8"
|
||||
down_revision = "76b60d407dfb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("is_visible", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET is_visible = true")
|
||||
op.alter_column("persona", "is_visible", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "is_visible")
|
||||
op.drop_column("persona", "display_priority")
|
||||
61
backend/alembic/versions/904e5138fffb_tags.py
Normal file
61
backend/alembic/versions/904e5138fffb_tags.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Tags
|
||||
|
||||
Revision ID: 904e5138fffb
|
||||
Revises: 891cd83c87a8
|
||||
Create Date: 2024-01-01 10:44:43.733974
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "904e5138fffb"
|
||||
down_revision = "891cd83c87a8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tag",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("tag_key", sa.String(), nullable=False),
|
||||
sa.Column("tag_value", sa.String(), nullable=False),
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
|
||||
),
|
||||
)
|
||||
op.create_table(
|
||||
"document__tag",
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.Column("tag_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"],
|
||||
["document.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["tag_id"],
|
||||
["tag.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("document_id", "tag_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"search_doc",
|
||||
sa.Column(
|
||||
"doc_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE search_doc SET doc_metadata = '{}' WHERE doc_metadata IS NULL")
|
||||
op.alter_column("search_doc", "doc_metadata", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("document__tag")
|
||||
op.drop_table("tag")
|
||||
op.drop_column("search_doc", "doc_metadata")
|
||||
520
backend/alembic/versions/b156fa702355_chat_reworked.py
Normal file
520
backend/alembic/versions/b156fa702355_chat_reworked.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""Chat Reworked
|
||||
|
||||
Revision ID: b156fa702355
|
||||
Revises: baf71f781b9e
|
||||
Create Date: 2023-12-12 00:57:41.823371
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import ENUM
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b156fa702355"
|
||||
down_revision = "baf71f781b9e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
searchtype_enum = ENUM(
|
||||
"KEYWORD", "SEMANTIC", "HYBRID", name="searchtype", create_type=True
|
||||
)
|
||||
recencybiassetting_enum = ENUM(
|
||||
"FAVOR_RECENT",
|
||||
"BASE_DECAY",
|
||||
"NO_DECAY",
|
||||
"AUTO",
|
||||
name="recencybiassetting",
|
||||
create_type=True,
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
searchtype_enum.create(bind)
|
||||
recencybiassetting_enum.create(bind)
|
||||
|
||||
# This is irrecoverable, whatever
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
|
||||
op.create_table(
|
||||
"search_doc",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.Column("chunk_ind", sa.Integer(), nullable=False),
|
||||
sa.Column("semantic_id", sa.String(), nullable=False),
|
||||
sa.Column("link", sa.String(), nullable=True),
|
||||
sa.Column("blurb", sa.String(), nullable=False),
|
||||
sa.Column("boost", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.Enum(DocumentSource, native=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("hidden", sa.Boolean(), nullable=False),
|
||||
sa.Column("score", sa.Float(), nullable=False),
|
||||
sa.Column("match_highlights", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=False),
|
||||
sa.Column("system_prompt", sa.Text(), nullable=False),
|
||||
sa.Column("task_prompt", sa.Text(), nullable=False),
|
||||
sa.Column("include_citations", sa.Boolean(), nullable=False),
|
||||
sa.Column("datetime_aware", sa.Boolean(), nullable=False),
|
||||
sa.Column("default_prompt", sa.Boolean(), nullable=False),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"persona__prompt",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("prompt_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["prompt_id"],
|
||||
["prompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
|
||||
)
|
||||
|
||||
# Changes to persona first so chat_sessions can have the right persona
|
||||
# The empty persona will be overwritten on server startup
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"search_type",
|
||||
searchtype_enum,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET search_type = 'HYBRID'")
|
||||
op.alter_column("persona", "search_type", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_relevance_filter", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET llm_relevance_filter = TRUE")
|
||||
op.alter_column("persona", "llm_relevance_filter", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_filter_extraction", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET llm_filter_extraction = TRUE")
|
||||
op.alter_column("persona", "llm_filter_extraction", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
recencybiassetting_enum,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET recency_bias = 'BASE_DECAY'")
|
||||
op.alter_column("persona", "recency_bias", nullable=False)
|
||||
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.execute("UPDATE persona SET description = ''")
|
||||
op.alter_column("persona", "description", nullable=False)
|
||||
op.create_foreign_key("persona__user_fk", "persona", "user", ["user_id"], ["id"])
|
||||
op.drop_column("persona", "datetime_aware")
|
||||
op.drop_column("persona", "tools")
|
||||
op.drop_column("persona", "hint_text")
|
||||
op.drop_column("persona", "apply_llm_relevance_filter")
|
||||
op.drop_column("persona", "retrieval_enabled")
|
||||
op.drop_column("persona", "system_text")
|
||||
|
||||
# Need to create a persona row so fk can work
|
||||
result = bind.execute(sa.text("SELECT 1 FROM persona WHERE id = 0"))
|
||||
exists = result.fetchone()
|
||||
if not exists:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, user_id, name, description, search_type, num_chunks,
|
||||
llm_relevance_filter, llm_filter_extraction, recency_bias,
|
||||
llm_model_version_override, default_persona, deleted
|
||||
) VALUES (
|
||||
0, NULL, '', '', 'HYBRID', NULL,
|
||||
TRUE, TRUE, 'BASE_DECAY', NULL, TRUE, FALSE
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
delete_statement = sa.text(
|
||||
"""
|
||||
DELETE FROM persona
|
||||
WHERE name = 'Danswer' AND default_persona = TRUE AND id != 0
|
||||
"""
|
||||
)
|
||||
|
||||
bind.execute(delete_statement)
|
||||
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
|
||||
"chat_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_column("chat_feedback", "chat_message_edit_number")
|
||||
op.drop_column("chat_feedback", "chat_message_chat_session_id")
|
||||
op.drop_column("chat_feedback", "chat_message_message_number")
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.Integer(),
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("parent_message", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("latest_child_message", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("citations", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
)
|
||||
op.add_column("chat_message", sa.Column("error", sa.Text(), nullable=True))
|
||||
op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
|
||||
)
|
||||
op.drop_column("chat_message", "parent_edit_number")
|
||||
op.drop_column("chat_message", "persona_id")
|
||||
op.drop_column("chat_message", "reference_docs")
|
||||
op.drop_column("chat_message", "edit_number")
|
||||
op.drop_column("chat_message", "latest")
|
||||
op.drop_column("chat_message", "message_number")
|
||||
op.add_column("chat_session", sa.Column("one_shot", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE chat_session SET one_shot = TRUE")
|
||||
op.alter_column("chat_session", "one_shot", nullable=False)
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True,
|
||||
)
|
||||
op.execute("UPDATE chat_session SET persona_id = 0")
|
||||
op.alter_column("chat_session", "persona_id", nullable=False)
|
||||
op.add_column(
|
||||
"document_retrieval_feedback",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback_qa_event_id_fkey",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("document_retrieval_feedback", "qa_event_id")
|
||||
|
||||
# Relation table must be created after the other tables are correct
|
||||
op.create_table(
|
||||
"chat_message__search_doc",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("search_doc_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_message_id"],
|
||||
["chat_message.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_doc_id"],
|
||||
["search_doc.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("chat_message_id", "search_doc_id"),
|
||||
)
|
||||
|
||||
# Needs to be created after chat_message id field is added
|
||||
op.create_foreign_key(
|
||||
"chat_feedback__chat_message_fk",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.drop_table("query_event")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_retrieval_feedback__chat_message_fk",
|
||||
"document_retrieval_feedback",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint("persona__user_fk", "persona", type_="foreignkey")
|
||||
op.drop_constraint("chat_message__prompt_fk", "chat_message", type_="foreignkey")
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("system_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"retrieval_enabled",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute("UPDATE persona SET retrieval_enabled = TRUE")
|
||||
op.alter_column("persona", "retrieval_enabled", nullable=False)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"apply_llm_relevance_filter",
|
||||
sa.BOOLEAN(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("hint_text", sa.TEXT(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"tools",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("datetime_aware", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.execute("UPDATE persona SET datetime_aware = TRUE")
|
||||
op.alter_column("persona", "datetime_aware", nullable=False)
|
||||
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.drop_column("persona", "recency_bias")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "search_type")
|
||||
op.drop_column("persona", "user_id")
|
||||
op.add_column(
|
||||
"document_retrieval_feedback",
|
||||
sa.Column("qa_event_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.drop_column("document_retrieval_feedback", "chat_message_id")
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"message_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("latest", sa.BOOLEAN(), autoincrement=False, nullable=False),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"reference_docs",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("persona_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"parent_edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona_id",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["persona_id"],
|
||||
["id"],
|
||||
)
|
||||
op.drop_column("chat_message", "error")
|
||||
op.drop_column("chat_message", "citations")
|
||||
op.drop_column("chat_message", "prompt_id")
|
||||
op.drop_column("chat_message", "rephrased_query")
|
||||
op.drop_column("chat_message", "latest_child_message")
|
||||
op.drop_column("chat_message", "parent_message")
|
||||
op.drop_column("chat_message", "id")
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_message_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_chat_session_id",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_feedback",
|
||||
sa.Column(
|
||||
"chat_message_edit_number",
|
||||
sa.INTEGER(),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.drop_column("chat_feedback", "chat_message_id")
|
||||
op.create_table(
|
||||
"query_event",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("query", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"selected_search_flow",
|
||||
sa.VARCHAR(),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("llm_answer", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("feedback", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"retrieved_document_ids",
|
||||
postgresql.ARRAY(sa.VARCHAR()),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("chat_session_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_session_id"],
|
||||
["chat_session.id"],
|
||||
name="fk_query_event_chat_session_id",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"], ["user.id"], name="query_event_user_id_fkey"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name="query_event_pkey"),
|
||||
)
|
||||
op.drop_table("chat_message__search_doc")
|
||||
op.drop_table("persona__prompt")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("search_doc")
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_combination",
|
||||
"chat_message",
|
||||
["chat_session_id", "message_number", "edit_number"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
|
||||
"chat_feedback",
|
||||
"chat_message",
|
||||
[
|
||||
"chat_message_chat_session_id",
|
||||
"chat_message_message_number",
|
||||
"chat_message_edit_number",
|
||||
],
|
||||
["chat_session_id", "message_number", "edit_number"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_retrieval_feedback_qa_event_id_fkey",
|
||||
"document_retrieval_feedback",
|
||||
"query_event",
|
||||
["qa_event_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS searchtype")
|
||||
op.execute("DROP TYPE IF EXISTS recencybiassetting")
|
||||
op.execute("DROP TYPE IF EXISTS documentsource")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Add llm_model_version_override to Persona
|
||||
|
||||
Revision ID: baf71f781b9e
|
||||
Revises: 50b683a8295c
|
||||
Create Date: 2023-12-06 21:56:50.286158
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "baf71f781b9e"
|
||||
down_revision = "50b683a8295c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("llm_model_version_override", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "llm_model_version_override")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add persona to chat_session
|
||||
|
||||
Revision ID: e86866a9c78a
|
||||
Revises: 80696cf850ae
|
||||
Create Date: 2023-11-26 02:51:47.657357
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e86866a9c78a"
|
||||
down_revision = "80696cf850ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
|
||||
op.drop_column("chat_session", "persona_id")
|
||||
@@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
|
||||
|
||||
@@ -4,7 +4,7 @@ from danswer.access.models import DocumentAccess
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
@@ -66,6 +68,12 @@ def verify_auth_setting() -> None:
|
||||
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def get_user_whitelist() -> list[str]:
|
||||
global _user_whitelist
|
||||
if _user_whitelist is None:
|
||||
@@ -102,10 +110,9 @@ def verify_email_domain(email: str) -> None:
|
||||
def send_user_verification_email(user_email: str, token: str) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Danswer Email Verification"
|
||||
msg["From"] = "no-reply@danswer.dev"
|
||||
msg["To"] = user_email
|
||||
|
||||
link = f"{WEB_DOMAIN}/verify-email?token={token}"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
@@ -170,6 +177,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.info(f"User {user.id} has registered.")
|
||||
optional_telemetry(record_type=RecordType.SIGN_UP, data={"user": "create"})
|
||||
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -253,9 +261,11 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
)
|
||||
|
||||
|
||||
optional_valid_user = fastapi_users.current_user(
|
||||
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=True
|
||||
)
|
||||
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
|
||||
# take care of that in `double_check_user` ourself. This is needed, since
|
||||
# we want the /me endpoint to still return a user even if they are not
|
||||
# yet verified, so that the frontend knows they exist
|
||||
optional_valid_user = fastapi_users.current_user(active=True, optional=True)
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
@@ -273,6 +283,12 @@ async def double_check_user(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -36,8 +36,9 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_broker_url = f"sqla+{connection_string}"
|
||||
celery_backend_url = f"db+{connection_string}"
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
|
||||
|
||||
@@ -208,8 +209,10 @@ def clean_old_temp_files_task(
|
||||
Currently handled async of the indexing job"""
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
for file in os.listdir(base_path):
|
||||
if file_age_in_hours(file) > age_threshold_in_hours:
|
||||
os.remove(Path(base_path) / file)
|
||||
full_file_path = Path(base_path) / file
|
||||
if file_age_in_hours(full_file_path) > age_threshold_in_hours:
|
||||
logger.info(f"Cleaning up uploaded file: {full_file_path}")
|
||||
os.remove(full_file_path)
|
||||
|
||||
|
||||
#####
|
||||
|
||||
@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.server.models import DeletionAttemptSnapshot
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
|
||||
@@ -11,8 +11,6 @@ connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -35,9 +33,8 @@ from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -173,14 +170,8 @@ def delete_connector_credential_pair(
|
||||
|
||||
# Clean up document sets / access information from Postgres
|
||||
# and sync these updates to Vespa
|
||||
cleanup_synced_entities__versioned = cast(
|
||||
Callable[[ConnectorCredentialPair, Session], None],
|
||||
fetch_versioned_implementation(
|
||||
"danswer.background.connector_deletion",
|
||||
"cleanup_synced_entities",
|
||||
),
|
||||
)
|
||||
cleanup_synced_entities__versioned(cc_pair, db_session)
|
||||
# TODO: add user group cleanup with `fetch_versioned_implementation`
|
||||
cleanup_synced_entities(cc_pair, db_session)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
delete_index_attempts(
|
||||
|
||||
75
backend/danswer/background/indexing/checkpointing.py
Normal file
75
backend/danswer/background/indexing/checkpointing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Experimental functionality related to splitting up indexing
|
||||
into a series of checkpoints to better handle intermittent failures
|
||||
/ jobs being killed by cloud providers."""
|
||||
import datetime
|
||||
|
||||
from danswer.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
|
||||
|
||||
def _2010_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _2020_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _default_end_time(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
) -> datetime.datetime:
|
||||
"""If year is before 2010, go to the beginning of 2010.
|
||||
If year is 2010-2020, go in 5 year increments.
|
||||
If year > 2020, then go in 180 day increments.
|
||||
|
||||
For connectors that don't support a `filter_by` and instead rely on `sort_by`
|
||||
for polling, then this will cause a massive duplication of fetches. For these
|
||||
connectors, you may want to override this function to return a more reasonable
|
||||
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
|
||||
last_successful_run = (
|
||||
datetime_to_utc(last_successful_run) if last_successful_run else None
|
||||
)
|
||||
if last_successful_run is None or last_successful_run < _2010_dt():
|
||||
return _2010_dt()
|
||||
|
||||
if last_successful_run < _2020_dt():
|
||||
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
|
||||
|
||||
return last_successful_run + datetime.timedelta(days=180)
|
||||
|
||||
|
||||
def find_end_time_for_indexing_attempt(
|
||||
last_successful_run: datetime.datetime | None, source_type: DocumentSource
|
||||
) -> datetime.datetime | None:
|
||||
# NOTE: source_type can be used to override the default for certain connectors
|
||||
end_of_window = _default_end_time(last_successful_run)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
if end_of_window < now:
|
||||
return end_of_window
|
||||
|
||||
# None signals that we should index up to current time
|
||||
return None
|
||||
|
||||
|
||||
def get_time_windows_for_index_attempt(
|
||||
last_successful_run: datetime.datetime, source_type: DocumentSource
|
||||
) -> list[tuple[datetime.datetime, datetime.datetime]]:
|
||||
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
|
||||
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
|
||||
|
||||
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
|
||||
start_of_window: datetime.datetime | None = last_successful_run
|
||||
while start_of_window:
|
||||
end_of_window = find_end_time_for_indexing_attempt(
|
||||
last_successful_run=start_of_window, source_type=source_type
|
||||
)
|
||||
time_windows.append(
|
||||
(
|
||||
start_of_window,
|
||||
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
)
|
||||
)
|
||||
start_of_window = end_of_window
|
||||
|
||||
return time_windows
|
||||
33
backend/danswer/background/indexing/dask_utils.py
Normal file
33
backend/danswer/background/indexing/dask_utils.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import asyncio
|
||||
|
||||
import psutil
|
||||
from dask.distributed import WorkerPlugin
|
||||
from distributed import Worker
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ResourceLogger(WorkerPlugin):
|
||||
def __init__(self, log_interval: int = 60 * 5):
|
||||
self.log_interval = log_interval
|
||||
|
||||
def setup(self, worker: Worker) -> None:
|
||||
"""This method will be called when the plugin is attached to a worker."""
|
||||
self.worker = worker
|
||||
worker.loop.add_callback(self.log_resources)
|
||||
|
||||
async def log_resources(self) -> None:
|
||||
"""Periodically log CPU and memory usage.
|
||||
|
||||
NOTE: must be async or else will clog up the worker indefinitely due to the fact that
|
||||
Dask uses Tornado under the hood (which is async)"""
|
||||
while True:
|
||||
cpu_percent = psutil.cpu_percent(interval=None)
|
||||
memory_available_gb = psutil.virtual_memory().available / (1024.0**3)
|
||||
# You can now log these values or send them to a monitoring service
|
||||
logger.debug(
|
||||
f"Worker {self.worker.address}: CPU usage {cpu_percent}%, Memory available {memory_available_gb}GB"
|
||||
)
|
||||
await asyncio.sleep(self.log_interval)
|
||||
@@ -4,12 +4,13 @@ not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from torch import multiprocessing
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -94,7 +95,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = multiprocessing.Process(target=func, args=args)
|
||||
process = multiprocessing.Process(target=func, args=args, daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
||||
260
backend/danswer/background/indexing/run_indexing.py
Normal file
260
backend/danswer/background/indexing/run_indexing.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import torch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_document_generator(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
|
||||
logger.info(f"Polling for updates between {start_time} and {end_time}")
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time.timestamp(), end=end_time.timestamp()
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# mark as started
|
||||
mark_attempt_in_progress(index_attempt, db_session)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline()
|
||||
db_connector = index_attempt.connector
|
||||
db_credential = index_attempt.credential
|
||||
last_successful_index_time = get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
run_end_dt = None
|
||||
for ind, (window_start, window_end) in enumerate(
|
||||
get_time_windows_for_index_attempt(
|
||||
last_successful_run=datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
),
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
doc_batch_generator = _get_document_generator(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
)
|
||||
|
||||
try:
|
||||
for doc_batch in doc_batch_generator:
|
||||
# check if connector is disabled mid run and stop if so
|
||||
db_session.refresh(db_connector)
|
||||
if db_connector.disabled:
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
# of the transactions when computing `NOW()`, so if we have
|
||||
# a long running transaction, the `time_updated` field will
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
)
|
||||
|
||||
run_end_dt = window_end
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if ind == 0 or db_connector.disabled:
|
||||
mark_attempt_failed(index_attempt, db_session, failure_reason=str(e))
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.SUCCESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
try:
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
logger.info(f"Setting task to use {num_threads} threads")
|
||||
torch.set_num_threads(num_threads)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
raise RuntimeError(
|
||||
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import dask
|
||||
import torch
|
||||
@@ -10,23 +9,19 @@ from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import LOG_LEVEL
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
@@ -35,15 +30,10 @@ from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.search.search_nlp_models import warm_up_models
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -57,6 +47,9 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
)
|
||||
|
||||
|
||||
"""Util funcs"""
|
||||
|
||||
|
||||
def _get_num_threads() -> int:
|
||||
"""Get # of "threads" to use for ML models in an indexing job. By default uses
|
||||
the torch implementation, which returns the # of physical cores on the machine.
|
||||
@@ -64,19 +57,34 @@ def _get_num_threads() -> int:
|
||||
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
|
||||
|
||||
|
||||
def should_create_new_indexing(
|
||||
def _should_create_new_indexing(
|
||||
connector: Connector, last_index: IndexAttempt | None, db_session: Session
|
||||
) -> bool:
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
# only one scheduled job per connector at a time
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
return time_since_index.total_seconds() >= connector.refresh_freq
|
||||
|
||||
|
||||
def mark_run_failed(
|
||||
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
|
||||
if index_attempt is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
index_attempt.status == IndexingStatus.FAILED
|
||||
or index_attempt.status == IndexingStatus.SUCCESS
|
||||
)
|
||||
|
||||
|
||||
def _mark_run_failed(
|
||||
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
|
||||
) -> None:
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
@@ -102,342 +110,141 @@ def mark_run_failed(
|
||||
)
|
||||
|
||||
|
||||
def create_indexing_jobs(
|
||||
db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
|
||||
) -> None:
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
ongoing_pairs: set[tuple[int | None, int | None]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(db_session=db_session, index_attempt_id=attempt_id)
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing_pairs: set[tuple[int | None, int | None]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
continue
|
||||
ongoing_pairs.add((attempt.connector_id, attempt.credential_id))
|
||||
|
||||
enabled_connectors = fetch_connectors(db_session, disabled_status=False)
|
||||
for connector in enabled_connectors:
|
||||
for association in connector.credentials:
|
||||
credential = association.credential
|
||||
|
||||
# check if there is an ogoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id) in ongoing_pairs:
|
||||
if attempt is None:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
|
||||
"indexing jobs"
|
||||
)
|
||||
continue
|
||||
ongoing_pairs.add((attempt.connector_id, attempt.credential_id))
|
||||
|
||||
last_attempt = get_last_attempt(connector.id, credential.id, db_session)
|
||||
if not should_create_new_indexing(connector, last_attempt, db_session):
|
||||
continue
|
||||
create_index_attempt(connector.id, credential.id, db_session)
|
||||
enabled_connectors = fetch_connectors(db_session, disabled_status=False)
|
||||
for connector in enabled_connectors:
|
||||
for association in connector.credentials:
|
||||
credential = association.credential
|
||||
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
)
|
||||
# check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id) in ongoing_pairs:
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt(connector.id, credential.id, db_session)
|
||||
if not _should_create_new_indexing(connector, last_attempt, db_session):
|
||||
continue
|
||||
create_index_attempt(connector.id, credential.id, db_session)
|
||||
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
attempt_status=IndexingStatus.NOT_STARTED,
|
||||
)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
# do nothing for ongoing jobs
|
||||
if not job.done():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if index_attempt.status == IndexingStatus.IN_PROGRESS or job.status == "error":
|
||||
mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# check to see if the job has been updated in the 3 hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in an hour. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
mark_run_failed(
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done() and not _is_indexing_job_marked_as_finished(
|
||||
index_attempt
|
||||
):
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
logger.error(job.exception())
|
||||
|
||||
job.release()
|
||||
del existing_jobs_copy[attempt_id]
|
||||
|
||||
if not index_attempt:
|
||||
logger.error(
|
||||
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
|
||||
"up indexing jobs"
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
or job.status == "error"
|
||||
):
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# check to see if the job has been updated in last n hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in an hour. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
|
||||
def _get_document_generator(
|
||||
db_session: Session, attempt: IndexAttempt
|
||||
) -> tuple[GenerateDocumentsOutput, float]:
|
||||
# "official" timestamp for this run
|
||||
# used for setting time bounds when fetching updates from apps and
|
||||
# is stored in the DB as the last successful run time if this run succeeds
|
||||
run_time = time.time()
|
||||
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
|
||||
run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
task = attempt.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector, new_credential_json = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential.credential_json,
|
||||
)
|
||||
if new_credential_json is not None:
|
||||
backend_update_credential_json(
|
||||
attempt.credential, new_credential_json, db_session
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
assert isinstance(runnable_connector, LoadConnector)
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
)
|
||||
last_run_time = get_last_successful_attempt_time(
|
||||
attempt.connector_id, attempt.credential_id, db_session
|
||||
)
|
||||
last_run_time_str = datetime.fromtimestamp(
|
||||
last_run_time, tz=timezone.utc
|
||||
).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(
|
||||
f"Polling for updates between {last_run_time_str} and {run_time_str}"
|
||||
)
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=last_run_time, end=run_time
|
||||
)
|
||||
|
||||
else:
|
||||
# Event types cannot be handled by a background type
|
||||
raise RuntimeError(f"Invalid task type: {task}")
|
||||
|
||||
return doc_batch_generator, run_time
|
||||
|
||||
doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
|
||||
|
||||
def _index(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
doc_batch_generator: GenerateDocumentsOutput,
|
||||
run_time: float,
|
||||
) -> None:
|
||||
indexing_pipeline = build_indexing_pipeline()
|
||||
|
||||
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
|
||||
db_connector = attempt.connector
|
||||
db_credential = attempt.credential
|
||||
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
try:
|
||||
for doc_batch in doc_batch_generator:
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
)
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
# of the transactions when computing `NOW()`, so if we have
|
||||
# a long running transaction, the `time_updated` field will
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
)
|
||||
|
||||
# check if connector is disabled mid run and stop if so
|
||||
db_session.refresh(db_connector)
|
||||
if db_connector.disabled:
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
mark_attempt_succeeded(attempt, db_session)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
attempt_status=IndexingStatus.SUCCESS,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexed or updated {document_count} total documents for a total of {chunk_count} chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Failed connector elapsed time: {time.time() - run_time} seconds"
|
||||
)
|
||||
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
|
||||
# The last attempt won't be marked failed until the next cycle's check for still in-progress attempts
|
||||
# The connector_credential_pair is marked failed here though to reflect correctly in UI asap
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector.id,
|
||||
credential_id=attempt.credential.id,
|
||||
attempt_status=IndexingStatus.FAILED,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
raise e
|
||||
|
||||
_index(db_session, index_attempt, doc_batch_generator, run_time)
|
||||
|
||||
|
||||
def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
try:
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
logger.info(f"Setting task to use {num_threads} threads")
|
||||
torch.set_num_threads(num_threads)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if attempt is None:
|
||||
raise RuntimeError(
|
||||
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector.id,
|
||||
credential_id=attempt.credential.id,
|
||||
attempt_status=IndexingStatus.IN_PROGRESS,
|
||||
)
|
||||
|
||||
_run_indexing(
|
||||
db_session=db_session,
|
||||
index_attempt=attempt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
|
||||
def kickoff_indexing_jobs(
|
||||
db_session: Session,
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
new_indexing_attempts = [
|
||||
attempt
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
with Session(engine) as db_session:
|
||||
new_indexing_attempts = [
|
||||
attempt
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
|
||||
|
||||
@@ -449,19 +256,23 @@ def kickoff_indexing_jobs(
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
with Session(engine) as db_session:
|
||||
mark_attempt_failed(
|
||||
attempt, db_session, failure_reason="Credential is null"
|
||||
)
|
||||
continue
|
||||
|
||||
run = client.submit(
|
||||
_run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
|
||||
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
|
||||
)
|
||||
if run:
|
||||
logger.info(
|
||||
@@ -476,9 +287,7 @@ def kickoff_indexing_jobs(
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
client: Client | SimpleJobClient
|
||||
if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
|
||||
client = SimpleJobClient(n_workers=num_workers)
|
||||
else:
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
@@ -489,6 +298,10 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
client = Client(cluster)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client = SimpleJobClient(n_workers=num_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
engine = get_sqlalchemy_engine()
|
||||
@@ -502,15 +315,20 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.info(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
existing_jobs = cleanup_indexing_jobs(
|
||||
db_session=db_session, existing_jobs=existing_jobs
|
||||
)
|
||||
create_indexing_jobs(db_session=db_session, existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
db_session=db_session, existing_jobs=existing_jobs, client=client
|
||||
)
|
||||
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
|
||||
create_indexing_jobs(existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=existing_jobs, client=client
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
@@ -518,8 +336,19 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True)
|
||||
def update__main() -> None:
|
||||
# needed for CUDA to work with multiprocessing
|
||||
# NOTE: needs to be done on application startup
|
||||
# before any other torch code has been run
|
||||
if not DASK_JOB_CLIENT_ENABLED:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
if not MODEL_SERVER_HOST:
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True, skip_cross_encoders=True)
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update__main()
|
||||
|
||||
@@ -1,581 +0,0 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_prompts import build_combined_query
|
||||
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
|
||||
from danswer.chat.chat_prompts import form_require_search_text
|
||||
from danswer.chat.chat_prompts import form_tool_followup_text
|
||||
from danswer.chat.chat_prompts import form_tool_less_followup_text
|
||||
from danswer.chat.chat_prompts import form_tool_section_text
|
||||
from danswer.chat.chat_prompts import form_user_prompt_text
|
||||
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
|
||||
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
|
||||
from danswer.chat.chat_prompts import YES_SEARCH
|
||||
from danswer.chat.personas import build_system_text_from_persona
|
||||
from danswer.chat.tools import call_tool
|
||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import FORCE_TOOL_PROMPT
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerChatModelOut
|
||||
from danswer.direct_qa.interfaces import StreamingError
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import search_chunks
|
||||
from danswer.server.models import RetrievalDocs
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
from danswer.utils.text_processing import has_unescaped_quote
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
LLM_CHAT_FAILURE_MSG = "The large-language-model failed to generate a valid response."
|
||||
|
||||
|
||||
def _parse_embedded_json_streamed_response(
|
||||
tokens: Iterator[str],
|
||||
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
|
||||
final_answer = False
|
||||
just_start_stream = False
|
||||
model_output = ""
|
||||
hold = ""
|
||||
finding_end = 0
|
||||
for token in tokens:
|
||||
model_output += token
|
||||
hold += token
|
||||
|
||||
if (
|
||||
final_answer is False
|
||||
and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
|
||||
):
|
||||
final_answer = True
|
||||
|
||||
if final_answer and '"actioninput":"' in model_output.lower().replace(
|
||||
" ", ""
|
||||
).replace("_", ""):
|
||||
if not just_start_stream:
|
||||
just_start_stream = True
|
||||
hold = ""
|
||||
|
||||
if has_unescaped_quote(hold):
|
||||
finding_end += 1
|
||||
hold = hold[: hold.find('"')]
|
||||
|
||||
if finding_end <= 1:
|
||||
if finding_end == 1:
|
||||
finding_end += 1
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=hold)
|
||||
hold = ""
|
||||
|
||||
model_final = extract_embedded_json(model_output)
|
||||
if "action" not in model_final or "action_input" not in model_final:
|
||||
raise ValueError("Model did not provide all required action values")
|
||||
|
||||
yield DanswerChatModelOut(
|
||||
model_raw=model_output,
|
||||
action=model_final["action"],
|
||||
action_input=model_final["action_input"],
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _find_last_index(
|
||||
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
|
||||
) -> int:
|
||||
"""From the back, find the index of the last element to include
|
||||
before the list exceeds the maximum"""
|
||||
running_sum = 0
|
||||
|
||||
last_ind = 0
|
||||
for i in range(len(lst) - 1, -1, -1):
|
||||
running_sum += lst[i]
|
||||
if running_sum > max_prompt_tokens:
|
||||
last_ind = i + 1
|
||||
break
|
||||
if last_ind >= len(lst):
|
||||
raise ValueError("Last message alone is too large!")
|
||||
return last_ind
|
||||
|
||||
|
||||
def danswer_chat_retrieval(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
llm: LLM,
|
||||
filters: IndexFilters,
|
||||
) -> list[InferenceChunk]:
|
||||
if history:
|
||||
query_combination_msgs = build_combined_query(query_message, history)
|
||||
reworded_query = llm.invoke(query_combination_msgs)
|
||||
else:
|
||||
reworded_query = query_message.message
|
||||
|
||||
search_query = SearchQuery(
|
||||
query=reworded_query,
|
||||
search_type=SearchType.HYBRID,
|
||||
filters=filters,
|
||||
favor_recent=False,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
ranked_chunks, unranked_chunks = search_chunks(
|
||||
query=search_query, document_index=get_default_document_index()
|
||||
)
|
||||
|
||||
if not ranked_chunks:
|
||||
return []
|
||||
|
||||
if unranked_chunks:
|
||||
ranked_chunks.extend(unranked_chunks)
|
||||
|
||||
filtered_ranked_chunks = [
|
||||
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
|
||||
]
|
||||
|
||||
# get all chunks that fit into the token limit
|
||||
usable_chunks = get_usable_chunks(
|
||||
chunks=filtered_ranked_chunks,
|
||||
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
|
||||
)
|
||||
|
||||
return usable_chunks
|
||||
|
||||
|
||||
def _drop_messages_history_overflow(
|
||||
system_msg: BaseMessage | None,
|
||||
system_token_count: int,
|
||||
history_msgs: list[BaseMessage],
|
||||
history_token_counts: list[int],
|
||||
final_msg: BaseMessage,
|
||||
final_msg_token_count: int,
|
||||
) -> list[BaseMessage]:
|
||||
"""As message history grows, messages need to be dropped starting from the furthest in the past.
|
||||
The System message should be kept if at all possible and the latest user input which is inserted in the
|
||||
prompt template must be included"""
|
||||
|
||||
if len(history_msgs) != len(history_token_counts):
|
||||
# This should never happen
|
||||
raise ValueError("Need exactly 1 token count per message for tracking overflow")
|
||||
|
||||
prompt: list[BaseMessage] = []
|
||||
|
||||
# Start dropping from the history if necessary
|
||||
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
|
||||
ind_prev_msg_start = _find_last_index(all_tokens)
|
||||
|
||||
if system_msg and ind_prev_msg_start <= len(history_msgs):
|
||||
prompt.append(system_msg)
|
||||
|
||||
prompt.extend(history_msgs[ind_prev_msg_start:])
|
||||
|
||||
prompt.append(final_msg)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str], links: list[str | None]
|
||||
) -> Iterator[str]:
|
||||
if not links:
|
||||
yield from tokens
|
||||
return
|
||||
|
||||
max_citation_num = len(links) + 1 # LLM is prompted to 1 index these
|
||||
curr_segment = ""
|
||||
prepend_bracket = False
|
||||
for token in tokens:
|
||||
# Special case of [1][ where ][ is a single token
|
||||
if prepend_bracket:
|
||||
curr_segment += "[" + curr_segment
|
||||
prepend_bracket = False
|
||||
|
||||
curr_segment += token
|
||||
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
citation_found = re.search(citation_pattern, curr_segment)
|
||||
|
||||
if citation_found:
|
||||
numerical_value = int(citation_found.group(1))
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
link = links[numerical_value - 1]
|
||||
if link:
|
||||
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||
|
||||
# In case there's another open bracket like [1][, don't want to match this
|
||||
possible_citation_found = None
|
||||
|
||||
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||
# citation that needs to be replaced with a link
|
||||
if possible_citation_found:
|
||||
continue
|
||||
|
||||
# Special case with back to back citations [1][2]
|
||||
if curr_segment and curr_segment[-1] == "[":
|
||||
curr_segment = curr_segment[:-1]
|
||||
prepend_bracket = True
|
||||
|
||||
yield curr_segment
|
||||
curr_segment = ""
|
||||
|
||||
if curr_segment:
|
||||
if prepend_bracket:
|
||||
yield "[" + curr_segment
|
||||
else:
|
||||
yield curr_segment
|
||||
|
||||
|
||||
def llm_contextless_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
system_text: str | None = None,
|
||||
tokenizer: Callable | None = None,
|
||||
) -> Iterator[DanswerAnswerPiece | StreamingError]:
|
||||
try:
|
||||
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
|
||||
|
||||
if system_text:
|
||||
tokenizer = tokenizer or get_default_llm_tokenizer()
|
||||
system_tokens = len(tokenizer(system_text))
|
||||
system_msg = SystemMessage(content=system_text)
|
||||
|
||||
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
|
||||
else:
|
||||
message_tokens = [msg.token_count for msg in messages]
|
||||
|
||||
last_msg_ind = _find_last_index(message_tokens)
|
||||
|
||||
remaining_user_msgs = prompt_msgs[last_msg_ind:]
|
||||
if not remaining_user_msgs:
|
||||
raise ValueError("Last user message is too long!")
|
||||
|
||||
if system_text:
|
||||
all_msgs = [system_msg] + remaining_user_msgs
|
||||
else:
|
||||
all_msgs = remaining_user_msgs
|
||||
|
||||
for token in get_default_llm().stream(all_msgs):
|
||||
yield DanswerAnswerPiece(answer_piece=token)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
|
||||
def llm_contextual_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona,
|
||||
user: User | None,
|
||||
tokenizer: Callable,
|
||||
db_session: Session,
|
||||
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
|
||||
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
|
||||
last_message = messages[-1]
|
||||
final_query_text = last_message.message
|
||||
previous_messages = messages[:-1]
|
||||
previous_msgs_as_basemessage = [
|
||||
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
|
||||
]
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
|
||||
if not final_query_text:
|
||||
raise ValueError("User chat message is empty.")
|
||||
|
||||
# Determine if a search is necessary to answer the user query
|
||||
user_req_search_text = form_require_search_text(last_message)
|
||||
last_user_msg = HumanMessage(content=user_req_search_text)
|
||||
|
||||
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
|
||||
danswer_system_tokens = len(tokenizer(run_search_system_text))
|
||||
last_user_msg_tokens = len(tokenizer(user_req_search_text))
|
||||
|
||||
need_search_prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=run_search_system_text),
|
||||
system_token_count=danswer_system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
model_out = llm.invoke(need_search_prompt)
|
||||
|
||||
# Model will output "Yes Search" if search is useful
|
||||
# Be a little forgiving though, if we match yes, it's good enough
|
||||
retrieved_chunks: list[InferenceChunk] = []
|
||||
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
|
||||
final_filters = IndexFilters(
|
||||
source_type=None,
|
||||
document_set=doc_set_filter,
|
||||
time_cutoff=None,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
|
||||
retrieved_chunks = danswer_chat_retrieval(
|
||||
query_message=last_message,
|
||||
history=previous_messages,
|
||||
llm=llm,
|
||||
filters=final_filters,
|
||||
)
|
||||
|
||||
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
|
||||
|
||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||
|
||||
last_user_msg_text = form_tool_less_followup_text(
|
||||
tool_output=tool_result_str,
|
||||
query=last_message.message,
|
||||
hint_text=persona.hint_text,
|
||||
)
|
||||
last_user_msg_tokens = len(tokenizer(last_user_msg_text))
|
||||
last_user_msg = HumanMessage(content=last_user_msg_text)
|
||||
|
||||
else:
|
||||
last_user_msg_tokens = len(tokenizer(final_query_text))
|
||||
last_user_msg = HumanMessage(content=final_query_text)
|
||||
|
||||
system_text = build_system_text_from_persona(persona)
|
||||
system_msg = SystemMessage(content=system_text) if system_text else None
|
||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||
|
||||
prompt = _drop_messages_history_overflow(
|
||||
system_msg=system_msg,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
links = [
|
||||
chunk.source_links[0] if chunk.source_links else None
|
||||
for chunk in retrieved_chunks
|
||||
]
|
||||
|
||||
for segment in extract_citations_from_stream(tokens, links):
|
||||
yield DanswerAnswerPiece(answer_piece=segment)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
|
||||
def llm_tools_enabled_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona,
|
||||
user: User | None,
|
||||
tokenizer: Callable,
|
||||
db_session: Session,
|
||||
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
|
||||
retrieval_enabled = persona.retrieval_enabled
|
||||
system_text = build_system_text_from_persona(persona)
|
||||
hint_text = persona.hint_text
|
||||
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
|
||||
|
||||
last_message = messages[-1]
|
||||
previous_messages = messages[:-1]
|
||||
previous_msgs_as_basemessage = [
|
||||
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
|
||||
]
|
||||
|
||||
# Failure reasons include:
|
||||
# - Invalid LLM output, wrong format or wrong/missing keys
|
||||
# - No "Final Answer" from model after tool calling
|
||||
# - LLM times out or is otherwise unavailable
|
||||
# - Calling invalid tool or tool call fails
|
||||
# - Last message has more tokens than model is set to accept
|
||||
# - Missing user input
|
||||
try:
|
||||
if not last_message.message:
|
||||
raise ValueError("User chat message is empty.")
|
||||
|
||||
# Build the prompt using the last user message
|
||||
user_text = form_user_prompt_text(
|
||||
query=last_message.message,
|
||||
tool_text=tool_text,
|
||||
hint_text=hint_text,
|
||||
)
|
||||
last_user_msg = HumanMessage(content=user_text)
|
||||
|
||||
# Count tokens once to reuse
|
||||
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
|
||||
system_tokens = len(tokenizer(system_text)) if system_text else 0
|
||||
last_user_msg_tokens = len(tokenizer(user_text))
|
||||
|
||||
prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=system_text) if system_text else None,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage,
|
||||
history_token_counts=previous_msg_token_counts,
|
||||
final_msg=last_user_msg,
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
llm = get_default_llm()
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
|
||||
final_result: DanswerChatModelOut | None = None
|
||||
final_answer_streamed = False
|
||||
|
||||
for result in _parse_embedded_json_streamed_response(tokens):
|
||||
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
|
||||
yield result
|
||||
final_answer_streamed = True
|
||||
|
||||
if isinstance(result, DanswerChatModelOut):
|
||||
final_result = result
|
||||
break
|
||||
|
||||
if final_answer_streamed:
|
||||
return
|
||||
|
||||
if final_result is None:
|
||||
raise RuntimeError("Model output finished without final output parsing.")
|
||||
|
||||
if (
|
||||
retrieval_enabled
|
||||
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
|
||||
):
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
|
||||
|
||||
final_filters = IndexFilters(
|
||||
source_type=None,
|
||||
document_set=doc_set_filter,
|
||||
time_cutoff=None,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
|
||||
retrieved_chunks = danswer_chat_retrieval(
|
||||
query_message=last_message,
|
||||
history=previous_messages,
|
||||
llm=llm,
|
||||
filters=final_filters,
|
||||
)
|
||||
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
|
||||
|
||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||
else:
|
||||
tool_result_str = call_tool(final_result)
|
||||
|
||||
# The AI's tool calling message
|
||||
tool_call_msg_text = final_result.model_raw
|
||||
tool_call_msg_token_count = len(tokenizer(tool_call_msg_text))
|
||||
|
||||
# Create the new message to use the results of the tool call
|
||||
tool_followup_text = form_tool_followup_text(
|
||||
tool_output=tool_result_str,
|
||||
query=last_message.message,
|
||||
hint_text=hint_text,
|
||||
)
|
||||
tool_followup_msg = HumanMessage(content=tool_followup_text)
|
||||
tool_followup_tokens = len(tokenizer(tool_followup_text))
|
||||
|
||||
# Drop previous messages, the drop order goes: previous messages in the history,
|
||||
# the last user prompt and generated intermediate messages from this recent prompt,
|
||||
# the system message, then finally the tool message that was the last thing generated
|
||||
follow_up_prompt = _drop_messages_history_overflow(
|
||||
system_msg=SystemMessage(content=system_text) if system_text else None,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=previous_msgs_as_basemessage
|
||||
+ [last_user_msg, AIMessage(content=tool_call_msg_text)],
|
||||
history_token_counts=previous_msg_token_counts
|
||||
+ [last_user_msg_tokens, tool_call_msg_token_count],
|
||||
final_msg=tool_followup_msg,
|
||||
final_msg_token_count=tool_followup_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(follow_up_prompt)
|
||||
|
||||
for result in _parse_embedded_json_streamed_response(tokens):
|
||||
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
|
||||
yield result
|
||||
final_answer_streamed = True
|
||||
|
||||
if final_answer_streamed is False:
|
||||
raise RuntimeError("LLM did not to produce a Final Answer after tool call")
|
||||
except Exception as e:
|
||||
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
|
||||
def llm_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona | None,
|
||||
tokenizer: Callable,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
|
||||
# Common error cases to keep in mind:
|
||||
# - User asks question about something long ago, due to context limit, the message is dropped
|
||||
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
|
||||
# - Model is too weak of an LLM, fails to follow instructions
|
||||
# - Bad persona design leads to confusing instructions to the model
|
||||
# - Bad configurations, too small token limit, mismatched tokenizer to LLM, etc.
|
||||
|
||||
# No setting/persona available therefore no retrieval and no additional tools
|
||||
if persona is None:
|
||||
return llm_contextless_chat_answer(messages)
|
||||
|
||||
# Persona is configured but with retrieval off and no tools
|
||||
# therefore cannot retrieve any context so contextless
|
||||
elif persona.retrieval_enabled is False and not persona.tools:
|
||||
return llm_contextless_chat_answer(
|
||||
messages, system_text=persona.system_text, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# No additional tools outside of Danswer retrieval, can use a more basic prompt
|
||||
# Doesn't require tool calling output format (all LLM outputs are therefore valid)
|
||||
elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
|
||||
return llm_contextual_chat_answer(
|
||||
messages=messages,
|
||||
persona=persona,
|
||||
tokenizer=tokenizer,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Use most flexible/complex prompt format that allows arbitrary tool calls
|
||||
# that are configured in the persona file
|
||||
# WARNING: this flow does not work well with weaker LLMs (anything below GPT-4)
|
||||
return llm_tools_enabled_chat_answer(
|
||||
messages=messages,
|
||||
persona=persona,
|
||||
tokenizer=tokenizer,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -1,274 +0,0 @@
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.configs.constants import CODE_BLOCK_PAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ToolInfo
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
|
||||
DANSWER_TOOL_NAME = "Current Search"
|
||||
DANSWER_TOOL_DESCRIPTION = (
|
||||
"A search tool that can find information on any topic "
|
||||
"including up to date and proprietary knowledge."
|
||||
)
|
||||
|
||||
DANSWER_SYSTEM_MSG = (
|
||||
"Given a conversation (between Human and Assistant) and a final message from Human, "
|
||||
"rewrite the last message to be a standalone question which captures required/relevant context "
|
||||
"from previous messages. This question must be useful for a semantic search engine. "
|
||||
"It is used for a natural language search."
|
||||
)
|
||||
|
||||
|
||||
YES_SEARCH = "Yes Search"
|
||||
NO_SEARCH = "No Search"
|
||||
REQUIRE_DANSWER_SYSTEM_MSG = (
|
||||
"You are a large language model whose only job is to determine if the system should call an external search tool "
|
||||
"to be able to answer the user's last message.\n"
|
||||
f'\nRespond with "{NO_SEARCH}" if:\n'
|
||||
f"- there is sufficient information in chat history to fully answer the user query\n"
|
||||
f"- there is enough knowledge in the LLM to fully answer the user query\n"
|
||||
f"- the user query does not rely on any specific knowledge\n"
|
||||
f'\nRespond with "{YES_SEARCH}" if:\n'
|
||||
"- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n"
|
||||
"- there is some uncertainty what the user is referring to\n\n"
|
||||
f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"'
|
||||
)
|
||||
|
||||
TOOL_TEMPLATE = """
|
||||
TOOLS
|
||||
------
|
||||
You can use tools to look up information that may be helpful in answering the user's \
|
||||
original question. The available tools are:
|
||||
|
||||
{tool_overviews}
|
||||
|
||||
RESPONSE FORMAT INSTRUCTIONS
|
||||
----------------------------
|
||||
When responding to me, please output a response in one of two formats:
|
||||
|
||||
**Option 1:**
|
||||
Use this if you want to use a tool. Markdown code snippet formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": string, \\ The action to take. {tool_names}
|
||||
"action_input": string \\ The input to the action
|
||||
}}
|
||||
```
|
||||
|
||||
**Option #2:**
|
||||
Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": string \\ You should put what you want to return to use here
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
TOOL_LESS_PROMPT = """
|
||||
Respond with a markdown code snippet in the following schema:
|
||||
|
||||
```json
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": string \\ You should put what you want to return to use here
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
USER_INPUT = """
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Here is the user's input \
|
||||
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
|
||||
|
||||
{user_input}
|
||||
"""
|
||||
|
||||
TOOL_FOLLOWUP = """
|
||||
TOOL RESPONSE:
|
||||
---------------------
|
||||
{tool_output}
|
||||
|
||||
USER'S INPUT
|
||||
--------------------
|
||||
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
|
||||
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
|
||||
If the tool response is not useful, ignore it completely.
|
||||
{optional_reminder}{hint}
|
||||
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
|
||||
"""
|
||||
|
||||
|
||||
TOOL_LESS_FOLLOWUP = """
|
||||
Refer to the following documents when responding to my final query. Ignore any documents that are not relevant.
|
||||
|
||||
CONTEXT DOCUMENTS:
|
||||
---------------------
|
||||
{context_str}
|
||||
|
||||
FINAL QUERY:
|
||||
--------------------
|
||||
{user_query}
|
||||
|
||||
{hint_text}
|
||||
"""
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
|
||||
if not chunks:
|
||||
return "No Results Found"
|
||||
|
||||
return "\n".join(
|
||||
f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
|
||||
for ind, chunk in enumerate(chunks, start=1)
|
||||
)
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def build_combined_query(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> list[BaseMessage]:
    """Build the message list asking the LLM to rewrite the final query standalone.

    Raises:
        ValueError: if the final message is empty.
    """
    if not query_message.message:
        raise ValueError("Can't rephrase/search an empty query")

    rephrase_request = HumanMessage(
        content=(
            "Help me rewrite this final message into a standalone query that takes into consideration the "
            "past messages of the conversation if relevant. This query is used with a semantic search engine to "
            "retrieve documents. You must ONLY return the rewritten query and nothing else. "
            "Remember, the search engine does not have access to the conversation history!"
            f"\n\nQuery:\n{query_message.message}"
        )
    )

    # System message first, then prior conversation, then the rephrase request
    return (
        [SystemMessage(content=DANSWER_SYSTEM_MSG)]
        + [translate_danswer_msg_to_langchain(msg) for msg in history]
        + [rephrase_request]
    )
|
||||
|
||||
|
||||
def form_require_search_single_msg_text(
    query_message: ChatMessage,
    history: list[ChatMessage],
) -> str:
    """Render the chat history plus the final query for the search-needed check."""
    parts: list[str] = []
    if history:
        parts.append("MESSAGE_HISTORY\n---------------\n")

    for msg in history:
        role = "AI" if msg.message_type == MessageType.ASSISTANT else "User"
        parts.append(f"{role}:\n```\n{msg.message}\n```\n\n")

    parts.append(f"\nFINAL QUERY:\n---------------\n{query_message.message}")

    return "".join(parts)
|
||||
|
||||
|
||||
def form_require_search_text(query_message: ChatMessage) -> str:
    """Append the explicit yes/no-search instruction to the user's message."""
    instruction = f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}"
    return query_message.message + instruction
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
    tool_output: str,
    query: str,
    hint_text: str | None,
    tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
    """Fill in the tool-less follow-up template with context, query, and hint."""
    hint_line = f"Hint: {hint_text}" if hint_text else ""
    filled = tool_followup_prompt.format(
        context_str=tool_output, user_query=query, hint_text=hint_line
    )
    return filled.strip()
|
||||
479
backend/danswer/chat/chat_utils.py
Normal file
479
backend/danswer/chat/chat_utils.py
Normal file
@@ -0,0 +1,479 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import lru_cache
|
||||
from typing import cast
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
|
||||
from danswer.prompts.chat_prompts import CITATION_REMINDER
|
||||
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
||||
|
||||
# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
# (clean_up_source title-cases the raw value and replaces underscores with spaces)
CONNECTOR_NAME_MAP = {
    "web": "Website",
    "requesttracker": "Request Tracker",
    "github": "GitHub",
    "file": "File Upload",
}
|
||||
|
||||
|
||||
def clean_up_source(source_str: str) -> str:
    """Map a connector enum string to a human-readable source name for the LLM.

    Falls back to a title-cased, de-underscored version of the raw value.
    """
    return CONNECTOR_NAME_MAP.get(source_str, source_str.replace("_", " ").title())
|
||||
|
||||
|
||||
def build_context_str(
    context_docs: list[LlmDoc | InferenceChunk],
    include_metadata: bool = True,
) -> str:
    """Concatenate documents into a single prompt context string.

    Each document is numbered and (optionally) prefixed with its title, source,
    and last-updated time; the content itself is wrapped in a code block.
    """
    sections: list[str] = []
    for doc_num, doc in enumerate(context_docs, start=1):
        header = ""
        if include_metadata:
            header += f"DOCUMENT {doc_num}: {doc.semantic_identifier}\n"
            header += f"Source: {clean_up_source(doc.source_type)}\n"
            if doc.updated_at:
                header += f"Updated: {doc.updated_at.strftime('%B %d, %Y %H:%M')}\n"
        sections.append(header + f"{CODE_BLOCK_PAT.format(doc.content.strip())}\n\n\n")
    return "".join(sections).strip()
|
||||
|
||||
|
||||
@lru_cache()
def build_chat_system_message(
    prompt: Prompt,
    context_exists: bool,
    llm_tokenizer: Callable,
    citation_line: str = REQUIRE_CITATION_STATEMENT,
    no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
    """Build the system message (and its token count) from the persona's prompt.

    When citations are enabled, a citation (or explicit no-citation) line is
    appended depending on whether context documents exist. Datetime-aware
    prompts additionally get the current date/time appended. Returns (None, 0)
    when the resulting prompt is empty.

    NOTE(review): lru_cache keys on the Prompt object's hash and the tokenizer
    identity; for datetime_aware prompts the cached message also freezes the
    embedded date/time — confirm this caching is intended.
    """
    system_prompt = prompt.system_prompt.strip()
    if prompt.include_citations:
        if context_exists:
            system_prompt += citation_line
        else:
            system_prompt += no_citation_line
    if prompt.datetime_aware:
        if system_prompt:
            system_prompt += (
                f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
            )
        else:
            # No configured system prompt: the date/time info becomes the whole prompt
            system_prompt = get_current_llm_day_time()

    if not system_prompt:
        return None, 0

    token_count = len(llm_tokenizer(system_prompt))
    system_msg = SystemMessage(content=system_prompt)

    return system_msg, token_count
|
||||
|
||||
|
||||
def build_task_prompt_reminders(
    prompt: Prompt,
    use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
    citation_str: str = CITATION_REMINDER,
    language_hint_str: str = LANGUAGE_HINT,
) -> str:
    """Append citation and language reminders (when applicable) to the task prompt."""
    extras = ""
    if prompt.include_citations:
        extras += citation_str
    if use_language_hint:
        extras += language_hint_str.lstrip()
    return prompt.task_prompt + extras
|
||||
|
||||
|
||||
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
    """Convert a retrieval InferenceChunk into the minimal LlmDoc used for prompting."""
    # Only the first source link (if any) is surfaced to the LLM
    first_link = inf_chunk.source_links[0] if inf_chunk.source_links else None
    return LlmDoc(
        document_id=inf_chunk.document_id,
        content=inf_chunk.content,
        semantic_identifier=inf_chunk.semantic_identifier,
        source_type=inf_chunk.source_type,
        updated_at=inf_chunk.updated_at,
        link=first_link,
    )
|
||||
|
||||
|
||||
def map_document_id_order(
    chunks: list[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
    """Assign each distinct document_id a rank based on first appearance order."""
    rank_by_doc_id: dict[str, int] = {}
    next_rank = 1 if one_indexed else 0
    for item in chunks:
        doc_id = item.document_id
        if doc_id in rank_by_doc_id:
            continue
        rank_by_doc_id[doc_id] = next_rank
        next_rank += 1
    return rank_by_doc_id
|
||||
|
||||
|
||||
def build_chat_user_message(
    chat_message: ChatMessage,
    prompt: Prompt,
    context_docs: list[LlmDoc],
    llm_tokenizer: Callable,
    all_doc_useful: bool,
    user_prompt_template: str = CHAT_USER_PROMPT,
    context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
    ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
    """Build the final user message (and its token count) for the chat LLM call.

    With context documents, the full template (context + task reminders) is
    filled in; without them a simpler context-free template — or just the raw
    query when there is no task prompt — is used instead.
    """
    user_query = chat_message.message

    if not context_docs:
        # Simpler prompt for cases where there is no context
        if prompt.task_prompt:
            user_prompt = context_free_template.format(
                task_prompt=prompt.task_prompt, user_query=user_query
            )
        else:
            user_prompt = user_query
        user_prompt = user_prompt.strip()
        token_count = len(llm_tokenizer(user_prompt))
        return HumanMessage(content=user_prompt), token_count

    context_docs_str = build_context_str(
        cast(list[LlmDoc | InferenceChunk], context_docs)
    )
    # presumably the ignore statement tells the LLM some docs may be irrelevant;
    # it is omitted when the caller marks every doc as useful — confirm with template
    optional_ignore = "" if all_doc_useful else ignore_str

    task_prompt_with_reminder = build_task_prompt_reminders(prompt)

    user_prompt = user_prompt_template.format(
        optional_ignore_statement=optional_ignore,
        context_docs_str=context_docs_str,
        task_prompt=task_prompt_with_reminder,
        user_query=user_query,
    ).strip()

    token_count = len(llm_tokenizer(user_prompt))

    return HumanMessage(content=user_prompt), token_count
|
||||
|
||||
|
||||
def _get_usable_chunks(
    chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
    """Take the longest prefix of chunks whose combined token count fits the limit."""
    kept: list[InferenceChunk] = []
    tokens_so_far = 0
    for chunk in chunks:
        chunk_tokens = check_number_of_tokens(chunk.content)
        if tokens_so_far + chunk_tokens > token_limit:
            break
        tokens_so_far += chunk_tokens
        kept.append(chunk)

    # try and return at least one chunk if possible. This chunk will
    # get truncated later on in the pipeline. This would only occur if
    # the first chunk is larger than the token limit (usually due to character
    # count -> token count mismatches caused by special characters / non-ascii
    # languages)
    if not kept and chunks:
        kept = [chunks[0]]

    return kept
|
||||
|
||||
|
||||
def get_usable_chunks(
    chunks: list[InferenceChunk],
    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
    offset: int = 0,
) -> list[InferenceChunk]:
    """Return the batch of chunks at batch number `offset` that fits in token_limit.

    Batches are formed greedily from the front of `chunks` via _get_usable_chunks;
    offset=0 yields the first batch, offset=1 the next, and so on.

    Raises:
        ValueError: if `offset` points past the available chunks.
    """
    offset_into_chunks = 0
    usable_chunks: list[InferenceChunk] = []
    # BUG FIX: previously `range(min(offset + 1, 1))`, which always evaluates to
    # range(1) and silently ignored any non-zero offset. Iterate offset + 1 times
    # so the prior batches are actually skipped.
    for _ in range(offset + 1):  # go through this process at least once
        if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
            raise ValueError(
                "Chunks offset too large, should not retry this many times"
            )

        usable_chunks = _get_usable_chunks(
            chunks=chunks[offset_into_chunks:], token_limit=token_limit
        )
        offset_into_chunks += len(usable_chunks)

    return usable_chunks
|
||||
|
||||
|
||||
def get_chunks_for_qa(
    chunks: list[InferenceChunk],
    llm_chunk_selection: list[bool],
    token_limit: float | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
    batch_offset: int = 0,
) -> list[int]:
    """
    Gives back indices of chunks to pass into the LLM for Q&A.

    Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
    by the LLM in a separate flow (this can be turned off)

    Note, the batch_offset calculation has to count the batches from the beginning each time as
    there's no way to know which chunks were included in the prior batches without recounting atm,
    this is somewhat slow as it requires tokenizing all the chunks again
    """
    batch_index = 0
    latest_batch_indices: list[int] = []
    token_count = 0

    # First iterate the LLM selected chunks, then iterate the rest if tokens remaining
    for selection_target in [True, False]:
        for ind, chunk in enumerate(chunks):
            # Skip chunks not in the current selection pass and chunks flagged
            # as unusable for QA via their metadata
            if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
                IGNORE_FOR_QA
            ):
                continue

            # We calculate it live in case the user uses a different LLM + tokenizer
            chunk_token = check_number_of_tokens(chunk.content)
            # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
            token_count += chunk_token + 50

            # Always use at least 1 chunk
            if (
                token_limit is None
                or token_count <= token_limit
                or not latest_batch_indices
            ):
                latest_batch_indices.append(ind)
                current_chunk_unused = False
            else:
                current_chunk_unused = True

            if token_limit is not None and token_count >= token_limit:
                # This batch is full; either it is a prior batch to skip past,
                # or it is the batch the caller asked for
                if batch_index < batch_offset:
                    batch_index += 1
                    if current_chunk_unused:
                        # The chunk that overflowed starts the next batch
                        latest_batch_indices = [ind]
                        token_count = chunk_token
                    else:
                        latest_batch_indices = []
                        token_count = 0
                else:
                    return latest_batch_indices

    # Fewer than batch_offset full batches exist; return whatever accumulated last
    return latest_batch_indices
|
||||
|
||||
|
||||
def create_chat_chain(
    chat_session_id: int,
    db_session: Session,
) -> tuple[ChatMessage, list[ChatMessage]]:
    """Build the linear chain of messages without including the root message.

    Follows latest_child_message pointers from the session's root message and
    returns (latest message, all preceding messages in order).

    Raises:
        ValueError: if the session has no messages.
        RuntimeError: if the root/chain structure is inconsistent.
    """
    mainline_messages: list[ChatMessage] = []
    all_chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session_id,
        user_id=None,
        db_session=db_session,
        skip_permission_check=True,
    )
    id_to_msg = {msg.id: msg for msg in all_chat_messages}

    if not all_chat_messages:
        raise ValueError("No messages in Chat Session")

    # The first message returned is expected to be the parentless root
    root_message = all_chat_messages[0]
    if root_message.parent_message is not None:
        raise RuntimeError(
            "Invalid root message, unable to fetch valid chat message sequence"
        )

    # Walk the chain of latest children, collecting everything after the root
    current_message: ChatMessage | None = root_message
    while current_message is not None:
        child_msg = current_message.latest_child_message
        if not child_msg:
            break
        current_message = id_to_msg.get(child_msg)

        if current_message is None:
            raise RuntimeError(
                "Invalid message chain,"
                "could not find next message in the same session"
            )

        mainline_messages.append(current_message)

    if not mainline_messages:
        raise RuntimeError("Could not trace chat message history")

    return mainline_messages[-1], mainline_messages[:-1]
|
||||
|
||||
|
||||
def combine_message_chain(
    messages: list[ChatMessage],
    msg_limit: int | None = 10,
    token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
) -> str:
    """Used for secondary LLM flows that require the chat history.

    Renders the most recent messages (newest last) as "ROLE:\\n<text>" blocks,
    dropping the oldest ones once the optional token budget is exhausted.
    """
    if msg_limit is not None:
        messages = messages[-msg_limit:]

    kept: list[str] = []
    tokens_used = 0
    # Walk newest -> oldest so the most recent messages survive the budget
    for message in reversed(messages):
        msg_tokens = message.token_count
        if token_limit is not None and tokens_used + msg_tokens > token_limit:
            break
        kept.insert(0, f"{message.message_type.value.upper()}:\n{message.message}")
        tokens_used += msg_tokens

    return "\n\n".join(kept)
|
||||
|
||||
|
||||
def find_last_index(
    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
    """From the back, find the index of the last element to include
    before the list exceeds the maximum.

    Raises:
        ValueError: if even the final element alone exceeds the budget.
    """
    cumulative = 0
    cutoff = 0
    for idx in reversed(range(len(lst))):
        cumulative += lst[idx]
        if cumulative > max_prompt_tokens:
            # Everything from idx backwards is over budget; keep idx + 1 onwards
            cutoff = idx + 1
            break
    if cutoff >= len(lst):
        raise ValueError("Last message alone is too large!")
    return cutoff
|
||||
|
||||
|
||||
def drop_messages_history_overflow(
    system_msg: BaseMessage | None,
    system_token_count: int,
    history_msgs: list[BaseMessage],
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""

    if len(history_msgs) != len(history_token_counts):
        # This should never happen
        raise ValueError("Need exactly 1 token count per message for tracking overflow")

    prompt: list[BaseMessage] = []

    # Start dropping from the history if necessary
    # System and final message tokens are appended at the END of the list, so
    # find_last_index (which scans from the back) sacrifices history messages first
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = find_last_index(all_tokens)

    # Keep the system message unless even it had to be dropped to fit the budget
    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)

    prompt.extend(history_msgs[ind_prev_msg_start:])

    prompt.append(final_msg)

    return prompt
|
||||
|
||||
|
||||
def extract_citations_from_stream(
    tokens: Iterator[str],
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    """Stream answer tokens while rewriting bracket citations.

    Citations like [1] (indices into context_docs) are remapped to the
    document's rank from doc_id_to_rank_map, turned into markdown links when
    the document has one, and a CitationInfo is emitted the first time each
    document is cited. Tokens ending in a partial citation (e.g. "[1") are
    buffered until the citation can be resolved.
    """
    max_citation_num = len(context_docs)
    curr_segment = ""
    prepend_bracket = False
    cited_inds = set()
    for token in tokens:
        # Special case of [1][ where ][ is a single token
        # This is where the model attempts to do consecutive citations like [1][2]
        # NOTE(review): curr_segment was just reset to "" whenever prepend_bracket
        # is set, so += here effectively prepends "[" — confirm it can't double text
        if prepend_bracket:
            curr_segment += "[" + curr_segment
            prepend_bracket = False

        curr_segment += token

        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
        citation_found = re.search(citation_pattern, curr_segment)

        if citation_found:
            numerical_value = int(citation_found.group(1))
            if 1 <= numerical_value <= max_citation_num:
                context_llm_doc = context_docs[
                    numerical_value - 1
                ]  # remove 1 index offset

                link = context_llm_doc.link
                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

                # Use the citation number for the document's rank in
                # the search (or selected docs) results
                curr_segment = re.sub(
                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
                )

                # Emit CitationInfo only once per distinct cited document
                if target_citation_num not in cited_inds:
                    cited_inds.add(target_citation_num)
                    yield CitationInfo(
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )

                # Turn [n] into the markdown link [[n]](url) when a link exists
                if link:
                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

                # In case there's another open bracket like [1][, don't want to match this
                possible_citation_found = None

        # if we see "[", but haven't seen the right side, hold back - this may be a
        # citation that needs to be replaced with a link
        if possible_citation_found:
            continue

        # Special case with back to back citations [1][2]
        if curr_segment and curr_segment[-1] == "[":
            curr_segment = curr_segment[:-1]
            prepend_bracket = True

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    # Flush whatever is left once the token stream ends
    if curr_segment:
        if prepend_bracket:
            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
        else:
            yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
106
backend/danswer/chat/load_yamls.py
Normal file
106
backend/danswer/chat/load_yamls.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.chat import get_prompt_by_name
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.chat import upsert_prompt
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.search.models import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
    """Upsert the default prompts defined in the prompts YAML config into the DB."""
    with open(prompts_yaml, "r") as file:
        data = yaml.safe_load(file)

    with Session(get_sqlalchemy_engine()) as db_session:
        for prompt_cfg in data.get("prompts", []):
            upsert_prompt(
                user_id=None,
                prompt_id=prompt_cfg.get("id"),
                name=prompt_cfg["name"],
                description=prompt_cfg["description"].strip(),
                system_prompt=prompt_cfg["system"].strip(),
                task_prompt=prompt_cfg["task"].strip(),
                include_citations=prompt_cfg["include_citations"],
                datetime_aware=prompt_cfg.get("datetime_aware", True),
                default_prompt=True,
                personas=None,
                shared=True,
                db_session=db_session,
                commit=True,
            )
|
||||
|
||||
|
||||
def load_personas_from_yaml(
    personas_yaml: str = PERSONAS_YAML,
    default_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
) -> None:
    """Upsert the default personas from the personas YAML config into the DB.

    Each persona's document sets are created on demand by name; its prompts
    must already exist (load_prompts_from_yaml seeds them) or a ValueError is
    raised.
    """
    with open(personas_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_personas = data.get("personas", [])
    with Session(get_sqlalchemy_engine()) as db_session:
        for persona in all_personas:
            doc_set_names = persona["document_sets"]
            doc_sets: list[DocumentSetDBModel] | None = [
                get_or_create_document_set_by_name(db_session, name)
                for name in doc_set_names
            ]

            # Assume if user hasn't set any document sets for the persona, the user may want
            # to later attach document sets to the persona manually, therefore, don't overwrite/reset
            # the document sets for the persona
            if not doc_sets:
                doc_sets = None

            prompt_set_names = persona["prompts"]
            if not prompt_set_names:
                prompts: list[PromptDBModel | None] | None = None
            else:
                prompts = [
                    get_prompt_by_name(
                        prompt_name, user_id=None, shared=True, db_session=db_session
                    )
                    for prompt_name in prompt_set_names
                ]
                # Referencing a prompt that doesn't exist is a config error
                if any([prompt is None for prompt in prompts]):
                    raise ValueError("Invalid Persona configs, not all prompts exist")

            if not prompts:
                prompts = None

            upsert_persona(
                user_id=None,
                persona_id=persona.get("id"),
                name=persona["name"],
                description=persona["description"],
                # Fall back to the system default chunk count when unspecified
                num_chunks=persona.get("num_chunks")
                if persona.get("num_chunks") is not None
                else default_chunks,
                llm_relevance_filter=persona.get("llm_relevance_filter"),
                llm_filter_extraction=persona.get("llm_filter_extraction"),
                llm_model_version_override=None,
                recency_bias=RecencyBiasSetting(persona["recency_bias"]),
                prompts=cast(list[PromptDBModel] | None, prompts),
                document_sets=doc_sets,
                default_persona=True,
                shared=True,
                db_session=db_session,
            )
|
||||
|
||||
|
||||
def load_chat_yamls(
    prompt_yaml: str = PROMPTS_YAML,
    personas_yaml: str = PERSONAS_YAML,
) -> None:
    """Seed the DB with the default prompts and personas from the YAML configs.

    Prompts are loaded first since personas look up prompts by name.
    """
    load_prompts_from_yaml(prompt_yaml)
    load_personas_from_yaml(personas_yaml)
|
||||
100
backend/danswer/chat/models.py
Normal file
100
backend/danswer/chat/models.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
    """This contains the minimal set information for the LLM portion including citations"""

    document_id: str
    content: str
    semantic_identifier: str
    source_type: DocumentSource
    updated_at: datetime | None
    # First source link of the underlying document, if any
    link: str | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
    rephrased_query: str | None = None
    predicted_flow: QueryFlow | None
    predicted_search: SearchType | None
    applied_source_filters: list[DocumentSource] | None
    applied_time_cutoff: datetime | None
    recency_bias_multiplier: float

    def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        """Serialize as usual, but render the time cutoff as an ISO-8601 string."""
        initial_dict = super().dict(*args, **kwargs)  # type: ignore
        initial_dict["applied_time_cutoff"] = (
            self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
        )
        return initial_dict
|
||||
|
||||
|
||||
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
    # Indices (into the retrieved chunk list) that the LLM judged relevant
    relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
    # A small piece of a complete answer. Used for streaming back answers.
    # Pieces are concatenated by the client to reconstruct the full answer.
    answer_piece: str | None  # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
# An intermediate representation of citations, later translated into
# a mapping of the citation [n] number to SearchDoc
class CitationInfo(BaseModel):
    # The [n] number as shown in the streamed answer text
    citation_num: int
    document_id: str
||||
|
||||
|
||||
class StreamingError(BaseModel):
    # Error message emitted mid-stream in place of further answer pieces
    error: str
|
||||
|
||||
|
||||
# A single quote with enough context to trace it back to its source document
class DanswerQuote(BaseModel):
    # This is during inference so everything is a string by this point
    quote: str
    document_id: str
    link: str | None
    source_type: str
    semantic_identifier: str
    blurb: str
|
||||
|
||||
class DanswerQuotes(BaseModel):
    # Container wrapping a list of DanswerQuote objects for a single answer
    quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerAnswer(BaseModel):
    # The generated answer text, if any
    answer: str | None
|
||||
|
||||
|
||||
# Full (non-streaming) QA response: search results plus the generated answer
class QAResponse(SearchResponse, DanswerAnswer):
    quotes: list[DanswerQuote] | None
    predicted_flow: QueryFlow
    predicted_search: SearchType
    eval_res_valid: bool | None = None
    llm_chunks_indices: list[int] | None = None
    error_msg: str | None = None
|
||||
|
||||
|
||||
# (final answer, supporting quotes) pair returned by non-streaming QA
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]


# Streaming QA yields answer fragments, quote sets, or errors as they occur
AnswerQuestionStreamReturn = Iterator[
    DanswerAnswerPiece | DanswerQuotes | StreamingError
]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
    # Token usage recorded for a single LLM call
    prompt_tokens: int
    response_tokens: int
|
||||
@@ -1,81 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import PERSONAS_YAML
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import ToolInfo
|
||||
|
||||
|
||||
def build_system_text_from_persona(persona: Persona) -> str | None:
    """Assemble the persona's system prompt, optionally appending the date/time.

    Returns None when the result would be empty.
    """
    system_text = (persona.system_text or "").strip()
    if persona.datetime_aware:
        # Format looks like: "October 16, 2023 14:30"
        now_str = datetime.now().strftime("%B %d, %Y %H:%M")
        system_text += (
            "\n\nAdditional Information:\n"
            f"\t- The current date and time is {now_str}."
        )

    return system_text or None
|
||||
|
||||
|
||||
def validate_tool_info(item: Any) -> ToolInfo:
    """Validate a persona tool entry from YAML; it must carry string name/description.

    Raises:
        ValueError: if the entry is not a dict with string name and description.
    """
    well_formed = (
        isinstance(item, dict)
        and isinstance(item.get("name"), str)
        and isinstance(item.get("description"), str)
    )
    if not well_formed:
        raise ValueError(
            "Invalid Persona configuration yaml Found, not all tools have name/description"
        )
    return ToolInfo(name=item["name"], description=item["description"])
|
||||
|
||||
|
||||
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
    """Upsert the default personas (with tools and document sets) from the YAML config."""
    with open(personas_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_personas = data.get("personas", [])
    # expire_on_commit=False keeps loaded ORM objects readable after commit
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
        for persona in all_personas:
            # Raises ValueError on malformed tool entries
            tools = [validate_tool_info(tool) for tool in persona["tools"]]

            doc_set_names = persona["document_sets"]
            doc_sets: list[DocumentSetDBModel] | None = [
                get_or_create_document_set_by_name(db_session, name)
                for name in doc_set_names
            ]

            # Assume if user hasn't set any document sets for the persona, the user may want
            # to later attach document sets to the persona manually, therefore, don't overwrite/reset
            # the document sets for the persona
            if not doc_sets:
                doc_sets = None

            upsert_persona(
                name=persona["name"],
                retrieval_enabled=persona.get("retrieval_enabled", True),
                # Default to knowing the date/time if not specified, however if there is no
                # system prompt, do not interfere with the flow by adding a
                # system prompt that is ONLY the date info, this would likely not be useful
                datetime_aware=persona.get(
                    "datetime_aware", bool(persona.get("system"))
                ),
                system_text=persona.get("system"),
                tools=tools,
                hint_text=persona.get("hint"),
                default_persona=True,
                document_sets=doc_sets,
                db_session=db_session,
            )
|
||||
@@ -1,12 +1,34 @@
|
||||
# Currently in the UI, each Persona only has one prompt, which is why there are 3 very similar personas defined below.
|
||||
|
||||
personas:
|
||||
- name: "Danswer"
|
||||
system: |
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
|
||||
Your responses are as INFORMATIVE and DETAILED as possible.
|
||||
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
|
||||
# Document Sets that this persona has access to, specified as a list of names here.
|
||||
# If left empty, the persona has access to all and only public docs
|
||||
# This id field can be left blank for other default personas, however an id 0 persona must exist
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Default"
|
||||
description: >
|
||||
Default Danswer Question Answering functionality.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
prompts:
|
||||
- "Answer-Question"
|
||||
# Default number of chunks to include as context, set to 0 to disable retrieval
|
||||
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
|
||||
# If selecting documents, user can bypass this up until NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
|
||||
# Each chunk is 512 tokens long
|
||||
num_chunks: 5
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
# Decay documents priority as they age, options are:
|
||||
# - favor_recent (2x base by default, configurable)
|
||||
# - base_decay
|
||||
# - no_decay
|
||||
# - auto (model chooses between favor_recent and base_decay based on user query)
|
||||
recency_bias: "auto"
|
||||
# Default Document Sets for this persona, specified as a list of names here.
|
||||
# If the document set by the name exists, it will be attached to the persona
|
||||
# If the document set by the name does not exist, it will be created as an empty document set with no connectors
|
||||
# The admin can then use the UI to add new connectors to the document set
|
||||
@@ -16,19 +38,28 @@ personas:
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
|
||||
retrieval_enabled: true
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# Format looks like: "October 16, 2023 14:30"
|
||||
datetime_aware: true
|
||||
# Personas can be given tools for Agentifying Danswer, however the tool call must be implemented in the code
|
||||
# Once implemented, it can be given to personas via the config.
|
||||
# Example of adding tools, it must follow this structure:
|
||||
# tools:
|
||||
# - name: "Calculator"
|
||||
# description: "Use this tool to accurately process math equations, counting, etc."
|
||||
# - name: "Current Weather"
|
||||
# description: "Call this to get the current weather info."
|
||||
tools: []
|
||||
# Short tip to pass near the end of the prompt to emphasize some requirement
|
||||
hint: "Try to be as informative as possible!"
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: >
|
||||
A less creative assistant which summarizes relevant documents but does not try to
|
||||
extrapolate any answers for you.
|
||||
prompts:
|
||||
- "Summarize"
|
||||
num_chunks: 5
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: >
|
||||
The least creative default assistant that only provides quotes from the documents.
|
||||
prompts:
|
||||
- "Paraphrase"
|
||||
num_chunks: 5
|
||||
llm_relevance_filter: true
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
|
||||
471
backend/danswer/chat/process_message.py
Normal file
471
backend/danswer/chat/process_message.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import build_chat_system_message
|
||||
from danswer.chat.chat_utils import build_chat_user_message
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import drop_messages_history_overflow
|
||||
from danswer.chat.chat_utils import extract_citations_from_stream
|
||||
from danswer.chat.chat_utils import get_chunks_for_qa
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
|
||||
from danswer.chat.chat_utils import map_document_id_order
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import CHUNK_SIZE
|
||||
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import DISABLED_GEN_AI_MSG
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_db_search_doc_by_id
|
||||
from danswer.db.chat import get_doc_query_identifiers_from_model
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.request_preprocessing import retrieval_preprocessing
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import full_chunk_search_generator
|
||||
from danswer.search.search_runner import inference_documents_from_ids
|
||||
from danswer.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_ai_chat_response(
|
||||
query_message: ChatMessage,
|
||||
history: list[ChatMessage],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
llm: LLM | None,
|
||||
llm_tokenizer: Callable,
|
||||
all_doc_useful: bool,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# Not an error if it's a user configuration
|
||||
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
|
||||
return
|
||||
|
||||
if query_message.prompt is None:
|
||||
raise RuntimeError("No prompt received for generating Gen AI answer.")
|
||||
|
||||
try:
|
||||
context_exists = len(context_docs) > 0
|
||||
|
||||
system_message_or_none, system_tokens = build_chat_system_message(
|
||||
prompt=query_message.prompt,
|
||||
context_exists=context_exists,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
history_basemessages, history_token_counts = translate_history_to_basemessages(
|
||||
history
|
||||
)
|
||||
|
||||
# Be sure the context_docs passed to build_chat_user_message
|
||||
# Is the same as passed in later for extracting citations
|
||||
user_message, user_tokens = build_chat_user_message(
|
||||
chat_message=query_message,
|
||||
prompt=query_message.prompt,
|
||||
context_docs=context_docs,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
all_doc_useful=all_doc_useful,
|
||||
)
|
||||
|
||||
prompt = drop_messages_history_overflow(
|
||||
system_msg=system_message_or_none,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=history_basemessages,
|
||||
history_token_counts=history_token_counts,
|
||||
final_msg=user_message,
|
||||
final_msg_token_count=user_tokens,
|
||||
)
|
||||
|
||||
# Good Debug/Breakpoint
|
||||
tokens = llm.stream(prompt)
|
||||
|
||||
yield from extract_citations_from_stream(
|
||||
tokens, context_docs, doc_id_to_rank_map
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
|
||||
def translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
for db_doc in db_docs:
|
||||
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
||||
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
||||
|
||||
citation_to_saved_doc_id_map: dict[int, int] = {}
|
||||
for citation in citations_list:
|
||||
if citation.citation_num not in citation_to_saved_doc_id_map:
|
||||
citation_to_saved_doc_id_map[
|
||||
citation.citation_num
|
||||
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
||||
|
||||
return citation_to_saved_doc_id_map
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
|
||||
default_chunk_size: int = CHUNK_SIZE,
|
||||
) -> Iterator[str]:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
message_text = new_msg_req.message
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
persona = chat_session.persona
|
||||
query_override = new_msg_req.query_override
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
raise RuntimeError(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
llm = None
|
||||
|
||||
llm_tokenizer = get_default_llm_token_encode()
|
||||
document_index = get_default_document_index()
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if parent_id is not None:
|
||||
parent_message = get_chat_message(
|
||||
chat_message_id=parent_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
parent_message = root_message
|
||||
|
||||
# Create new message at the right place in the tree and update the parent's child pointer
|
||||
# Don't commit yet until we verify the chat message chain
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=parent_message,
|
||||
prompt_id=prompt_id,
|
||||
message=message_text,
|
||||
token_count=len(llm_tokenizer(message_text)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Create linear history of messages
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if final_msg.id != new_user_message.id:
|
||||
db_session.rollback()
|
||||
raise RuntimeError(
|
||||
"The new message was not on the mainline. "
|
||||
"Be sure to update the chat pointers before calling this."
|
||||
)
|
||||
|
||||
# Save now to save the latest chat message
|
||||
db_session.commit()
|
||||
|
||||
run_search = False
|
||||
# Retrieval options are only None if reference_doc_ids are provided
|
||||
if retrieval_options is not None and persona.num_chunks != 0:
|
||||
if retrieval_options.run_search == OptionalSearchSetting.ALWAYS:
|
||||
run_search = True
|
||||
elif retrieval_options.run_search == OptionalSearchSetting.NEVER:
|
||||
run_search = False
|
||||
else:
|
||||
run_search = check_if_need_search(
|
||||
query_message=final_msg, history=history_msgs, llm=llm
|
||||
)
|
||||
|
||||
rephrased_query = None
|
||||
if reference_doc_ids:
|
||||
identifier_tuples = get_doc_query_identifiers_from_model(
|
||||
search_doc_ids=reference_doc_ids,
|
||||
chat_session=chat_session,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
# May extend to include chunk ranges
|
||||
llm_docs: list[LlmDoc] = inference_documents_from_ids(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=get_default_document_index(),
|
||||
)
|
||||
doc_id_to_rank_map = map_document_id_order(
|
||||
cast(list[InferenceChunk | LlmDoc], llm_docs)
|
||||
)
|
||||
|
||||
# In case the search doc is deleted, just don't include it
|
||||
# though this should never happen
|
||||
db_search_docs_or_none = [
|
||||
get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
|
||||
for doc_id in reference_doc_ids
|
||||
]
|
||||
|
||||
reference_db_search_docs = [
|
||||
db_sd for db_sd in db_search_docs_or_none if db_sd
|
||||
]
|
||||
|
||||
elif run_search:
|
||||
rephrased_query = (
|
||||
history_based_query_rephrase(
|
||||
query_message=final_msg, history=history_msgs, llm=llm
|
||||
)
|
||||
if query_override is None
|
||||
else query_override
|
||||
)
|
||||
|
||||
(
|
||||
retrieval_request,
|
||||
predicted_search_type,
|
||||
predicted_flow,
|
||||
) = retrieval_preprocessing(
|
||||
query=rephrased_query,
|
||||
retrieval_details=cast(RetrievalDetails, retrieval_options),
|
||||
persona=persona,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
documents_generator = full_chunk_search_generator(
|
||||
search_query=retrieval_request,
|
||||
document_index=document_index,
|
||||
)
|
||||
time_cutoff = retrieval_request.filters.time_cutoff
|
||||
recency_bias_multiplier = retrieval_request.recency_bias_multiplier
|
||||
run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
|
||||
|
||||
# First fetch and return the top chunks to the UI so the user can
|
||||
# immediately see some results
|
||||
top_chunks = cast(list[InferenceChunk], next(documents_generator))
|
||||
|
||||
# Get ranking of the documents for citation purposes later
|
||||
doc_id_to_rank_map = map_document_id_order(
|
||||
cast(list[InferenceChunk | LlmDoc], top_chunks)
|
||||
)
|
||||
|
||||
top_docs = chunks_to_search_docs(top_chunks)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
|
||||
for top_doc in top_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=predicted_flow,
|
||||
predicted_search=predicted_search_type,
|
||||
applied_source_filters=retrieval_request.filters.source_type,
|
||||
applied_time_cutoff=time_cutoff,
|
||||
recency_bias_multiplier=recency_bias_multiplier,
|
||||
).dict()
|
||||
yield get_json_line(initial_response)
|
||||
|
||||
# Get the final ordering of chunks for the LLM call
|
||||
llm_chunk_selection = cast(list[bool], next(documents_generator))
|
||||
|
||||
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
|
||||
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=[
|
||||
index for index, value in enumerate(llm_chunk_selection) if value
|
||||
]
|
||||
if run_llm_chunk_filter
|
||||
else []
|
||||
).dict()
|
||||
yield get_json_line(llm_relevance_filtering_response)
|
||||
|
||||
# Prep chunks to pass to LLM
|
||||
num_llm_chunks = (
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
)
|
||||
llm_chunks_indices = get_chunks_for_qa(
|
||||
chunks=top_chunks,
|
||||
llm_chunk_selection=llm_chunk_selection,
|
||||
token_limit=num_llm_chunks * default_chunk_size,
|
||||
)
|
||||
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
|
||||
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
|
||||
|
||||
else:
|
||||
llm_docs = []
|
||||
doc_id_to_rank_map = {}
|
||||
reference_db_search_docs = None
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt_id,
|
||||
# message=,
|
||||
rephrased_query=rephrased_query,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
# error=,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# If no prompt is provided, this is interpreted as not wanting an AI Answer
|
||||
# Simply provide/save the retrieval results
|
||||
if final_msg.prompt is None:
|
||||
gen_ai_response_message = partial_response(
|
||||
message="",
|
||||
token_count=0,
|
||||
citations=None,
|
||||
error=None,
|
||||
)
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield get_json_line(msg_detail_response.dict())
|
||||
|
||||
# Stop here after saving message details, the above still needs to be sent for the
|
||||
# message id to send the next follow-up message
|
||||
return
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
response_packets = generate_ai_chat_response(
|
||||
query_message=final_msg,
|
||||
history=history_msgs,
|
||||
context_docs=llm_docs,
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
llm=llm,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
all_doc_useful=reference_doc_ids is not None,
|
||||
)
|
||||
|
||||
# Capture outputs and errors
|
||||
llm_output = ""
|
||||
error: str | None = None
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in response_packets:
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
token = packet.answer_piece
|
||||
if token:
|
||||
llm_output += token
|
||||
elif isinstance(packet, StreamingError):
|
||||
error = packet.error
|
||||
elif isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
continue
|
||||
|
||||
yield get_json_line(packet.dict())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
# This will be the issue 99% of the time
|
||||
error_packet = StreamingError(
|
||||
error="LLM failed to respond, have you set your API key?"
|
||||
)
|
||||
|
||||
yield get_json_line(error_packet.dict())
|
||||
return
|
||||
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
db_citations = None
|
||||
if reference_db_search_docs:
|
||||
db_citations = translate_citations(
|
||||
citations_list=citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = partial_response(
|
||||
message=llm_output,
|
||||
token_count=len(llm_tokenizer(llm_output)),
|
||||
citations=db_citations,
|
||||
error=error,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield get_json_line(msg_detail_response.dict())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# Frontend will erase whatever answer and show this instead
|
||||
error_packet = StreamingError(error="Failed to parse LLM output")
|
||||
|
||||
yield get_json_line(error_packet.dict())
|
||||
68
backend/danswer/chat/prompts.yaml
Normal file
68
backend/danswer/chat/prompts.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
prompts:
|
||||
# This id field can be left blank for other default prompts, however an id 0 prompt must exist
|
||||
# This is to act as a default
|
||||
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
|
||||
- id: 0
|
||||
name: "Answer-Question"
|
||||
description: "Answers user questions using retrieved context!"
|
||||
# System Prompt (as shown in UI)
|
||||
system: >
|
||||
You are a question answering system that is constantly learning and improving.
|
||||
|
||||
You can process and comprehend vast amounts of text and utilize this knowledge to provide
|
||||
grounded, accurate, and concise answers to diverse queries.
|
||||
|
||||
You always clearly communicate ANY UNCERTAINTY in your answer.
|
||||
# Task Prompt (as shown in UI)
|
||||
task: >
|
||||
Answer my query based on the documents provided.
|
||||
The documents may not all be relevant, ignore any documents that are not directly relevant
|
||||
to the most recent user query.
|
||||
|
||||
I have not read or seen any of the documents and do not want to read them.
|
||||
|
||||
If there are no relevant documents, refer to the chat history and existing knowledge.
|
||||
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
|
||||
# Format looks like: "October 16, 2023 14:30"
|
||||
datetime_aware: true
|
||||
# Prompts the LLM to include citations in the for [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Summarize"
|
||||
description: "Summarize relevant information from retrieved context!"
|
||||
system: >
|
||||
You are a text summarizing assistant that highlights the most important knowledge from the
|
||||
context provided, prioritizing the information that relates to the user query.
|
||||
|
||||
You ARE NOT creative and always stick to the provided documents.
|
||||
If there are no documents, refer to the conversation history.
|
||||
|
||||
IMPORTANT: YOU ONLY SUMMARIZE THE IMPORTANT INFORMATION FROM THE PROVIDED DOCUMENTS,
|
||||
NEVER USE YOUR OWN KNOWLEDGE.
|
||||
task: >
|
||||
Summarize the documents provided in relation to the query below.
|
||||
NEVER refer to the documents by number, I do not have them in the same order as you.
|
||||
Do not make up any facts, only use what is in the documents.
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "Paraphrase"
|
||||
description: "Recites information from retrieved context! Least creative but most safe!"
|
||||
system: >
|
||||
Quote and cite relevant information from provided context based on the user query.
|
||||
|
||||
You only provide quotes that are EXACT substrings from provided documents!
|
||||
|
||||
If there are no documents provided,
|
||||
simply tell the user that there are no documents to reference.
|
||||
|
||||
You NEVER generate new text or phrases outside of the citation.
|
||||
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
|
||||
task: >
|
||||
Provide EXACT quotes from the provided documents above. Do not generate any new text that is not
|
||||
directly from the documents.
|
||||
datetime_aware: true
|
||||
include_citations: true
|
||||
@@ -1,7 +1,115 @@
|
||||
from danswer.direct_qa.interfaces import DanswerChatModelOut
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
|
||||
@@ -3,11 +3,16 @@ import os
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
#####
|
||||
APP_HOST = "0.0.0.0"
|
||||
APP_PORT = 8080
|
||||
# API_PREFIX is used to prepend a base path for all API routes
|
||||
# generally used if using a reverse proxy which doesn't support stripping the `/api`
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
|
||||
#####
|
||||
@@ -15,10 +20,9 @@ APP_PORT = 8080
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
|
||||
# DISABLE_GENERATIVE_AI will turn of the question answering part of Danswer.
|
||||
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
#####
|
||||
@@ -39,7 +43,7 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
|
||||
SECRET = os.environ.get("SECRET", "")
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 86400)
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400
|
||||
) # 1 day
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
@@ -56,7 +60,6 @@ VALID_EMAIL_DOMAINS = (
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
@@ -67,12 +70,12 @@ OAUTH_CLIENT_SECRET = (
|
||||
or ""
|
||||
)
|
||||
|
||||
# The following Basic Auth configs are not supported by the frontend UI
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
|
||||
@@ -80,7 +83,7 @@ SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
#####
|
||||
# DB Configs
|
||||
#####
|
||||
DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices
|
||||
DOCUMENT_INDEX_NAME = "danswer_index"
|
||||
# Vespa is now the default document index store for both keyword and vector
|
||||
DOCUMENT_INDEX_TYPE = os.environ.get(
|
||||
"DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value
|
||||
@@ -93,7 +96,10 @@ VESPA_DEPLOYMENT_ZIP = (
|
||||
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
|
||||
)
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
INDEX_BATCH_SIZE = 16
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
@@ -140,80 +146,17 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Query Configs
|
||||
#####
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
# We feed in document chunks until we reach this token limit.
|
||||
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
|
||||
# may be smaller which could result in passing in more total chunks
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
|
||||
)
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
|
||||
)
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default
|
||||
)
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2
|
||||
DISABLE_TIME_FILTER_EXTRACTION = (
|
||||
os.environ.get("DISABLE_TIME_FILTER_EXTRACTION", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
|
||||
# Keyword Search Drop Stopwords
|
||||
# If user has changed the default model, would most likely be to use a multilingual
|
||||
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
|
||||
if os.environ.get("EDIT_KEYWORD_QUERY"):
|
||||
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
|
||||
else:
|
||||
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
|
||||
|
||||
#####
|
||||
# Text Processing Configs
|
||||
# Indexing Configs
|
||||
#####
|
||||
CHUNK_SIZE = 512 # Tokens by embedding model
|
||||
CHUNK_OVERLAP = int(CHUNK_SIZE * 0.05) # 5% overlap
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
|
||||
|
||||
#####
|
||||
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
|
||||
#####
|
||||
BI_ENCODER_HOST = "localhost"
|
||||
BI_ENCODER_PORT = 9000
|
||||
CROSS_ENCODER_HOST = "localhost"
|
||||
CROSS_ENCODER_PORT = 9000
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
DYNAMIC_CONFIG_STORE = os.environ.get(
|
||||
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
|
||||
)
|
||||
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
|
||||
# notset, debug, info, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
|
||||
# only handles some failures (Confluence = handles API call failures, Google
|
||||
# Drive = handles failures pulling files / parsing them)
|
||||
@@ -225,8 +168,57 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
CHUNK_OVERLAP = 0
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
MINI_CHUNK_SIZE = 150
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 1))
|
||||
|
||||
|
||||
#####
|
||||
# Model Server Configs
|
||||
#####
|
||||
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
|
||||
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
|
||||
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
|
||||
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
|
||||
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
|
||||
|
||||
# specify this env variable directly to have a different model server for the background
|
||||
# indexing job vs the api server so that background indexing does not effect query-time
|
||||
# performance
|
||||
INDEXING_MODEL_SERVER_HOST = (
|
||||
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
DYNAMIC_CONFIG_STORE = os.environ.get(
|
||||
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
|
||||
)
|
||||
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
|
||||
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
# used to allow the background indexing jobs to use a different embedding
|
||||
# model server than the API server
|
||||
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
|
||||
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
|
||||
)
|
||||
# Logs every model prompt and output, mostly used for development or exploration purposes
|
||||
LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
|
||||
)
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
# notset, debug, info, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
|
||||
@@ -1,3 +1,75 @@
|
||||
import os
|
||||
|
||||
FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
|
||||
from danswer.configs.model_configs import CHUNK_SIZE
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
NUM_RERANKED_RESULTS = 15
|
||||
# We feed in document chunks until we reach this token limit.
|
||||
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
|
||||
# significantly smaller which could result in passing in more total chunks.
|
||||
# There is also a slight bit of overhead, not accounted for here such as separator patterns
|
||||
# between the docs, metadata for the docs, etc.
|
||||
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
|
||||
# model token limit
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (CHUNK_SIZE * 5)
|
||||
)
|
||||
DEFAULT_NUM_CHUNKS_FED_TO_CHAT: float = (
|
||||
float(NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL) / CHUNK_SIZE
|
||||
)
|
||||
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
|
||||
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (CHUNK_SIZE * 3)
|
||||
)
|
||||
# For selecting a different LLM question-answering prompt format
|
||||
# Valid values: default, cot, weak
|
||||
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default
|
||||
)
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
DISABLE_LLM_FILTER_EXTRACTION = (
|
||||
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should be used to decide if a search would help given the chat history
|
||||
DISABLE_LLM_CHOOSE_SEARCH = (
|
||||
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
# Keyword Search Drop Stopwords
|
||||
# If user has changed the default model, would most likely be to use a multilingual
|
||||
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
|
||||
if os.environ.get("EDIT_KEYWORD_QUERY"):
|
||||
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
|
||||
else:
|
||||
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.66)))
|
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
# if the title is separated out. Title is most of a "boost" than a separate field.
|
||||
TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
|
||||
)
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
|
||||
# The backend logic for this being True isn't fully supported yet
|
||||
HARD_DELETE_CHATS = False
|
||||
|
||||
@@ -11,11 +11,13 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
|
||||
TITLE = "title"
|
||||
SECTION_CONTINUATION = "section_continuation"
|
||||
EMBEDDINGS = "embeddings"
|
||||
TITLE_EMBEDDING = "title_embedding"
|
||||
ALLOWED_USERS = "allowed_users"
|
||||
ACCESS_CONTROL_LIST = "access_control_list"
|
||||
DOCUMENT_SETS = "document_sets"
|
||||
TIME_FILTER = "time_filter"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
MATCH_HIGHLIGHTS = "match_highlights"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -35,26 +37,31 @@ SCORE = "score"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
|
||||
# Prompt building constants:
|
||||
GENERAL_SEP_PAT = "\n-----\n"
|
||||
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
|
||||
DOC_SEP_PAT = "---NEW DOCUMENT---"
|
||||
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
|
||||
QUESTION_PAT = "Query:"
|
||||
THOUGHT_PAT = "Thought:"
|
||||
ANSWER_PAT = "Answer:"
|
||||
FINAL_ANSWER_PAT = "Final Answer:"
|
||||
UNCERTAINTY_PAT = "?"
|
||||
QUOTE_PAT = "Quote:"
|
||||
QUOTES_PAT_PLURAL = "Quotes:"
|
||||
INVALID_PAT = "Invalid:"
|
||||
# For chunking/processing chunks
|
||||
TITLE_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
# For combining attributes, doesn't have to be unique/perfect to work
|
||||
INDEX_SEPARATOR = "==="
|
||||
|
||||
|
||||
# Messages
|
||||
DISABLED_GEN_AI_MSG = (
|
||||
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
|
||||
"Please contact them if you wish to have this enabled.\n"
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
INGESTION_API = "ingestion_api"
|
||||
SLACK = "slack"
|
||||
WEB = "web"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
REQUESTTRACKER = "requesttracker"
|
||||
GITHUB = "github"
|
||||
GURU = "guru"
|
||||
BOOKSTACK = "bookstack"
|
||||
@@ -86,11 +93,6 @@ class AuthType(str, Enum):
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
ENDORSE = "endorse" # boost this document for all future queries
|
||||
REJECT = "reject" # down-boost this document for all future queries
|
||||
@@ -100,7 +102,7 @@ class SearchFeedbackType(str, Enum):
|
||||
|
||||
class MessageType(str, Enum):
|
||||
# Using OpenAI standards, Langchain equivalent shown in comment
|
||||
# System message is always constructed on the fly, not saved
|
||||
SYSTEM = "system" # SystemMessage
|
||||
USER = "user" # HumanMessage
|
||||
ASSISTANT = "assistant" # AIMessage
|
||||
DANSWER = "danswer" # FunctionMessage
|
||||
|
||||
@@ -41,11 +41,10 @@ DISABLE_DANSWER_BOT_FILTER_DETECT = (
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the bot is configured individually
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Add the per document feedback blocks that affect the document rankings via boosting
|
||||
ENABLE_SLACK_DOC_FEEDBACK = (
|
||||
os.environ.get("ENABLE_SLACK_DOC_FEEDBACK", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
|
||||
@@ -3,11 +3,11 @@ import os
|
||||
#####
|
||||
# Embedding/Reranking Model Configs
|
||||
#####
|
||||
CHUNK_SIZE = 512
|
||||
# Important considerations when choosing models
|
||||
# Max tokens count needs to be high considering use case (at least 512)
|
||||
# Models used must be MIT or Apache license
|
||||
# Inference/Indexing speed
|
||||
|
||||
# https://huggingface.co/DOCUMENT_ENCODER_MODEL
|
||||
# The useable models configured as below must be SentenceTransformer compatible
|
||||
DOCUMENT_ENCODER_MODEL = (
|
||||
@@ -21,6 +21,7 @@ NORMALIZE_EMBEDDINGS = (
|
||||
os.environ.get("NORMALIZE_EMBEDDINGS") or "False"
|
||||
).lower() == "true"
|
||||
# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
|
||||
# Currently unused
|
||||
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
|
||||
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
|
||||
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
|
||||
@@ -34,7 +35,12 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
|
||||
|
||||
|
||||
# Cross Encoder Settings
|
||||
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
|
||||
ENABLE_RERANKING_ASYNC_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
|
||||
)
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
|
||||
)
|
||||
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
|
||||
CROSS_ENCODER_MODEL_ENSEMBLE = [
|
||||
"cross-encoder/ms-marco-MiniLM-L-4-v2",
|
||||
@@ -70,6 +76,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
|
||||
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
|
||||
# If using Azure, it's the engine name, for example: Danswer
|
||||
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
|
||||
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
|
||||
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
|
||||
FAST_GEN_AI_MODEL_VERSION = (
|
||||
os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
|
||||
)
|
||||
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = (
|
||||
@@ -80,9 +91,14 @@ GEN_AI_API_KEY = (
|
||||
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
|
||||
# API Version, such as (for Azure): 2023-09-15-preview
|
||||
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
# LiteLLM custom_llm_provider
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
|
||||
# This next restriction is only used for chat ATM, used to expire old messages as needed
|
||||
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
|
||||
# History for secondary LLM flows, not primary chat flow, generally we don't need to
|
||||
# include as much as possible as this just bumps up the cost unnecessarily
|
||||
GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
|
||||
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/README.md"} -->
|
||||
|
||||
# Writing a new Danswer Connector
|
||||
This README covers how to contribute a new Connector for Danswer. It includes an overview of the design, interfaces,
|
||||
and required changes.
|
||||
|
||||
@@ -8,6 +8,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.bookstack.client import BookStackApiClient
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -72,13 +73,21 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, book: dict[str, Any]
|
||||
) -> Document:
|
||||
url = bookstack_client.build_app_url("/books/" + str(book.get("slug")))
|
||||
title = str(book.get("name", ""))
|
||||
text = book.get("name", "") + "\n" + book.get("description", "")
|
||||
updated_at_str = (
|
||||
str(book.get("updated_at")) if book.get("updated_at") is not None else None
|
||||
)
|
||||
return Document(
|
||||
id="book:" + str(book.get("id")),
|
||||
id="book__" + str(book.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Book: " + str(book.get("name")),
|
||||
metadata={"type": "book", "updated_at": str(book.get("updated_at"))},
|
||||
semantic_identifier="Book: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "book"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -91,13 +100,23 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
+ "/chapter/"
|
||||
+ str(chapter.get("slug"))
|
||||
)
|
||||
title = str(chapter.get("name", ""))
|
||||
text = chapter.get("name", "") + "\n" + chapter.get("description", "")
|
||||
updated_at_str = (
|
||||
str(chapter.get("updated_at"))
|
||||
if chapter.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
return Document(
|
||||
id="chapter:" + str(chapter.get("id")),
|
||||
id="chapter__" + str(chapter.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Chapter: " + str(chapter.get("name")),
|
||||
metadata={"type": "chapter", "updated_at": str(chapter.get("updated_at"))},
|
||||
semantic_identifier="Chapter: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "chapter"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -105,13 +124,23 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, shelf: dict[str, Any]
|
||||
) -> Document:
|
||||
url = bookstack_client.build_app_url("/shelves/" + str(shelf.get("slug")))
|
||||
title = str(shelf.get("name", ""))
|
||||
text = shelf.get("name", "") + "\n" + shelf.get("description", "")
|
||||
updated_at_str = (
|
||||
str(shelf.get("updated_at"))
|
||||
if shelf.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
return Document(
|
||||
id="shelf:" + str(shelf.get("id")),
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Shelf: " + str(shelf.get("name")),
|
||||
metadata={"type": "shelf", "updated_at": shelf.get("updated_at")},
|
||||
semantic_identifier="Shelf: " + title,
|
||||
title=title,
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "shelf"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -119,7 +148,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
bookstack_client: BookStackApiClient, page: dict[str, Any]
|
||||
) -> Document:
|
||||
page_id = str(page.get("id"))
|
||||
page_name = str(page.get("name"))
|
||||
title = str(page.get("name", ""))
|
||||
page_data = bookstack_client.get("/pages/" + page_id, {})
|
||||
url = bookstack_client.build_app_url(
|
||||
"/books/"
|
||||
@@ -127,17 +156,24 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
+ "/page/"
|
||||
+ str(page_data.get("slug"))
|
||||
)
|
||||
page_html = (
|
||||
"<h1>" + html.escape(page_name) + "</h1>" + str(page_data.get("html"))
|
||||
)
|
||||
page_html = "<h1>" + html.escape(title) + "</h1>" + str(page_data.get("html"))
|
||||
text = parse_html_page_basic(page_html)
|
||||
updated_at_str = (
|
||||
str(page_data.get("updated_at"))
|
||||
if page_data.get("updated_at") is not None
|
||||
else None
|
||||
)
|
||||
time.sleep(0.1)
|
||||
return Document(
|
||||
id="page:" + page_id,
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Page: " + str(page_name),
|
||||
metadata={"type": "page", "updated_at": page_data.get("updated_at")},
|
||||
semantic_identifier="Page: " + str(title),
|
||||
title=str(title),
|
||||
doc_updated_at=time_str_to_utc(updated_at_str)
|
||||
if updated_at_str is not None
|
||||
else None,
|
||||
metadata={"type": "page"},
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
|
||||
@@ -2,10 +2,12 @@ from collections.abc import Callable
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
@@ -13,11 +15,12 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.html_utils import format_document_soup
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
@@ -84,6 +87,53 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
|
||||
return wiki_base, space, is_confluence_cloud
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _get_user(user_id: str, confluence_client: Confluence) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
user_not_found = "Unknown User"
|
||||
|
||||
try:
|
||||
return confluence_client.get_user_details_by_accountid(user_id).get(
|
||||
"displayName", user_not_found
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unable to get the User Display Name with the id: '{user_id}' - {e}"
|
||||
)
|
||||
return user_not_found
|
||||
|
||||
|
||||
def parse_html_page(text: str, confluence_client: Confluence) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
text (str): The page content
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.attrs["ri:userkey"]
|
||||
)
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(user_id, confluence_client))
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def _comment_dfs(
|
||||
comments_str: str,
|
||||
comment_pages: Collection[dict[str, Any]],
|
||||
@@ -91,7 +141,9 @@ def _comment_dfs(
|
||||
) -> str:
|
||||
for comment_page in comment_pages:
|
||||
comment_html = comment_page["body"]["storage"]["value"]
|
||||
comments_str += "\nComment:\n" + parse_html_page_basic(comment_html)
|
||||
comments_str += "\nComment:\n" + parse_html_page(
|
||||
comment_html, confluence_client
|
||||
)
|
||||
child_comment_pages = confluence_client.get_page_child_by_type(
|
||||
comment_page["id"],
|
||||
type="comment",
|
||||
@@ -281,9 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if not page_html:
|
||||
logger.debug("Page is empty, skipping: %s", page_url)
|
||||
continue
|
||||
page_text = (
|
||||
page.get("title", "") + "\n" + parse_html_page_basic(page_html)
|
||||
)
|
||||
page_text = parse_html_page(page_html, self.confluence_client)
|
||||
comments_text = self._fetch_comments(self.confluence_client, page_id)
|
||||
page_text += comments_text
|
||||
|
||||
@@ -294,7 +344,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=[author] if author else None,
|
||||
primary_owners=[BasicExpertInfo(email=author)]
|
||||
if author
|
||||
else None,
|
||||
metadata={
|
||||
"Wiki Space Name": self.space,
|
||||
},
|
||||
|
||||
@@ -1,45 +1,71 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
|
||||
import chardet
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_METADATA_FLAG = "#DANSWER_METADATA="
|
||||
|
||||
def extract_metadata(line: str) -> dict | None:
|
||||
html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
|
||||
hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
|
||||
|
||||
html_comment_match = re.search(html_comment_pattern, line)
|
||||
hashtag_match = re.search(hashtag_pattern, line)
|
||||
|
||||
if html_comment_match:
|
||||
json_str = html_comment_match.group(1)
|
||||
elif hashtag_match:
|
||||
json_str = hashtag_match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads("{" + json_str + "}")
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def read_pdf_file(file: IO[Any], file_name: str, pdf_pass: str | None = None) -> str:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
# if marked as encrypted and a password is provided, try to decrypt
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
decrypt_success = False
|
||||
if pdf_pass is not None:
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error(f"Unable to decrypt pdf {file_name}")
|
||||
else:
|
||||
logger.info(f"No Password available to to decrypt pdf {file_name}")
|
||||
|
||||
if not decrypt_success:
|
||||
# By user request, keep files that are unreadable just so they
|
||||
# can be discoverable by title.
|
||||
return ""
|
||||
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
# If marked as encrypted and a password is provided, try to decrypt
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
decrypt_success = False
|
||||
if pdf_pass is not None:
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error(f"Unable to decrypt pdf {file_name}")
|
||||
else:
|
||||
logger.info(f"No Password available to to decrypt pdf {file_name}")
|
||||
|
||||
if not decrypt_success:
|
||||
# By user request, keep files that are unreadable just so they
|
||||
# can be discoverable by title.
|
||||
return ""
|
||||
|
||||
return "\n".join(page.extract_text() for page in pdf_reader.pages)
|
||||
except PdfStreamError:
|
||||
logger.exception(f"PDF file {file_name} is not a valid PDF")
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read PDF {file_name}")
|
||||
return ""
|
||||
|
||||
# File is still discoverable by title
|
||||
# but the contents are not included as they cannot be parsed
|
||||
return ""
|
||||
|
||||
|
||||
def is_macos_resource_fork_file(file_name: str) -> bool:
|
||||
@@ -66,16 +92,33 @@ def load_files_from_zip(
|
||||
yield file_info, file
|
||||
|
||||
|
||||
def read_file(file_reader: IO[Any]) -> tuple[str, dict[str, Any]]:
|
||||
def detect_encoding(file_path: str | Path) -> str:
|
||||
with open(file_path, "rb") as file:
|
||||
raw_data = file.read(50000) # Read a portion of the file to guess encoding
|
||||
return chardet.detect(raw_data)["encoding"] or "utf-8"
|
||||
|
||||
|
||||
def read_file(
|
||||
file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace"
|
||||
) -> tuple[str, dict]:
|
||||
metadata = {}
|
||||
file_content_raw = ""
|
||||
for ind, line in enumerate(file_reader):
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8")
|
||||
line = str(line)
|
||||
try:
|
||||
line = line.decode(encoding) if isinstance(line, bytes) else line
|
||||
except UnicodeDecodeError:
|
||||
line = (
|
||||
line.decode(encoding, errors=errors)
|
||||
if isinstance(line, bytes)
|
||||
else line
|
||||
)
|
||||
|
||||
if ind == 0 and line.startswith(_METADATA_FLAG):
|
||||
metadata = json.loads(line.replace(_METADATA_FLAG, "", 1).strip())
|
||||
if ind == 0:
|
||||
metadata_or_none = extract_metadata(line)
|
||||
if metadata_or_none is not None:
|
||||
metadata = metadata_or_none
|
||||
else:
|
||||
file_content_raw += line
|
||||
else:
|
||||
file_content_raw += line
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.text_processing import is_valid_email
|
||||
|
||||
|
||||
def datetime_to_utc(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
dt = parse(datetime_str)
|
||||
return datetime_to_utc(dt)
|
||||
|
||||
|
||||
def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
if info.first_name and info.last_name:
|
||||
return f"{info.first_name} {info.middle_initial} {info.last_name}"
|
||||
|
||||
if info.display_name:
|
||||
return info.display_name
|
||||
|
||||
if info.email and is_valid_email(info.email):
|
||||
return info.email
|
||||
|
||||
if info.first_name:
|
||||
return info.first_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
@@ -3,16 +3,17 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dateutil.parser import parse
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
@@ -60,26 +61,32 @@ def fetch_jira_issues_batch(
|
||||
logger.warning(f"Found Jira object not of type Issue {jira}")
|
||||
continue
|
||||
|
||||
ticket_updated_time = parse(jira.fields.updated)
|
||||
|
||||
semantic_rep = (
|
||||
f"Jira Ticket Summary: {jira.fields.summary}\n"
|
||||
f"Description: {jira.fields.description}\n"
|
||||
+ "\n".join(
|
||||
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
|
||||
)
|
||||
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
|
||||
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
|
||||
)
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
|
||||
author = None
|
||||
try:
|
||||
author = BasicExpertInfo(
|
||||
display_name=jira.fields.creator.displayName,
|
||||
email=jira.fields.creator.emailAddress,
|
||||
)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=semantic_rep)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=jira.fields.summary,
|
||||
doc_updated_at=ticket_updated_time.astimezone(timezone.utc),
|
||||
metadata={},
|
||||
doc_updated_at=time_str_to_utc(jira.fields.updated),
|
||||
primary_owners=[author] if author is not None else None,
|
||||
# TODO add secondary_owners if needed
|
||||
metadata={"label": jira.fields.labels} if jira.fields.labels else {},
|
||||
)
|
||||
)
|
||||
return doc_batch, len(batch)
|
||||
|
||||
@@ -140,11 +140,7 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
html_content = article_details["html_content"]
|
||||
article_content = parse_html_page_basic(html_content)
|
||||
doc_text = (
|
||||
f"workspace: {self.workspace}\n"
|
||||
f"category: {article['category_name']}\n"
|
||||
f"article: {article_details['title']} - "
|
||||
f"{article_details.get('description', '')}\n"
|
||||
f"{article_content}"
|
||||
f"{article_details.get('description', '')}\n{article_content}".strip()
|
||||
)
|
||||
|
||||
document = Document(
|
||||
@@ -154,7 +150,10 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
semantic_identifier=article_details["title"],
|
||||
doc_updated_at=updated_at,
|
||||
primary_owners=authors,
|
||||
metadata={},
|
||||
metadata={
|
||||
"workspace": self.workspace,
|
||||
"category": article["category_name"],
|
||||
},
|
||||
)
|
||||
|
||||
doc_batch.append(document)
|
||||
@@ -190,8 +189,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 * 360 # 1 year
|
||||
latest_docs = document360_connector.poll_source(one_day_ago, current)
|
||||
one_year_ago = current - 24 * 60 * 60 * 360
|
||||
latest_docs = document360_connector.poll_source(one_year_ago, current)
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
|
||||
@@ -21,6 +21,7 @@ from danswer.connectors.linear.connector import LinearConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.notion.connector import NotionConnector
|
||||
from danswer.connectors.productboard.connector import ProductboardConnector
|
||||
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
|
||||
from danswer.connectors.slab.connector import SlabConnector
|
||||
from danswer.connectors.slack.connector import SlackLoadConnector
|
||||
from danswer.connectors.slack.connector import SlackPollConnector
|
||||
@@ -53,6 +54,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLAB: SlabConnector,
|
||||
DocumentSource.NOTION: NotionConnector,
|
||||
DocumentSource.ZULIP: ZulipConnector,
|
||||
DocumentSource.REQUESTTRACKER: RequestTrackerConnector,
|
||||
DocumentSource.GURU: GuruConnector,
|
||||
DocumentSource.LINEAR: LinearConnector,
|
||||
DocumentSource.HUBSPOT: HubSpotConnector,
|
||||
|
||||
0
backend/danswer/connectors/file/__init__.py
Normal file
0
backend/danswer/connectors/file/__init__.py
Normal file
@@ -8,9 +8,11 @@ from typing import IO
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.file_utils import detect_encoding
|
||||
from danswer.connectors.cross_connector_utils.file_utils import load_files_from_zip
|
||||
from danswer.connectors.cross_connector_utils.file_utils import read_file
|
||||
from danswer.connectors.cross_connector_utils.file_utils import read_pdf_file
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.file.utils import check_file_ext_is_valid
|
||||
from danswer.connectors.file.utils import get_file_ext
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -31,11 +33,12 @@ def _open_files_at_location(
|
||||
if extension == ".zip":
|
||||
for file_info, file in load_files_from_zip(file_path, ignore_dirs=True):
|
||||
yield file_info.filename, file
|
||||
elif extension == ".txt" or extension == ".pdf":
|
||||
mode = "r"
|
||||
if extension == ".pdf":
|
||||
mode = "rb"
|
||||
with open(file_path, mode) as file:
|
||||
elif extension in [".txt", ".md", ".mdx"]:
|
||||
encoding = detect_encoding(file_path)
|
||||
with open(file_path, "r", encoding=encoding, errors="replace") as file:
|
||||
yield os.path.basename(file_path), file
|
||||
elif extension == ".pdf":
|
||||
with open(file_path, "rb") as file:
|
||||
yield os.path.basename(file_path), file
|
||||
else:
|
||||
logger.warning(f"Skipping file '{file_path}' with extension '{extension}'")
|
||||
@@ -61,13 +64,20 @@ def _process_file(
|
||||
else:
|
||||
file_content_raw, metadata = read_file(file)
|
||||
|
||||
dt_str = metadata.get("doc_updated_at")
|
||||
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
|
||||
|
||||
return [
|
||||
Document(
|
||||
id=file_name,
|
||||
sections=[Section(link=metadata.get("link", ""), text=file_content_raw)],
|
||||
sections=[
|
||||
Section(link=metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=time_updated,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=metadata.get("primary_owners"),
|
||||
secondary_owners=metadata.get("secondary_owners"),
|
||||
metadata={},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import IO
|
||||
|
||||
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
|
||||
|
||||
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"]
|
||||
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf", ".md", ".mdx"]
|
||||
|
||||
|
||||
def get_file_ext(file_path_or_name: str | Path) -> str:
|
||||
|
||||
@@ -37,10 +37,9 @@ def _batch_github_objects(
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
full_context = f"Pull-Request {pull_request.title}\n{pull_request.body}"
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
sections=[Section(link=pull_request.html_url, text=full_context)],
|
||||
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier=pull_request.title,
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
@@ -48,7 +47,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata={
|
||||
"merged": pull_request.merged,
|
||||
"merged": str(pull_request.merged),
|
||||
"state": pull_request.state,
|
||||
},
|
||||
)
|
||||
@@ -60,10 +59,9 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
|
||||
|
||||
def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
full_context = f"Issue {issue.title}\n{issue.body}"
|
||||
return Document(
|
||||
id=issue.html_url,
|
||||
sections=[Section(link=issue.html_url, text=full_context)],
|
||||
sections=[Section(link=issue.html_url, text=issue.body or "")],
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier=issue.title,
|
||||
# updated_at is UTC time but is timezone unaware
|
||||
|
||||
@@ -32,7 +32,6 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self,
|
||||
workspaces: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
use_end_time: bool = False,
|
||||
continue_on_fail: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
hide_user_info: bool = False,
|
||||
) -> None:
|
||||
@@ -40,7 +39,6 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.batch_size: int = batch_size
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.use_end_time = use_end_time
|
||||
self.hide_user_info = hide_user_info
|
||||
|
||||
def _get_auth_header(self) -> dict[str, str]:
|
||||
@@ -102,7 +100,12 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
break
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
logger.error(f"Error fetching transcripts: {response.text}")
|
||||
raise
|
||||
|
||||
data = response.json()
|
||||
call_transcripts = data.get("callTranscripts", [])
|
||||
@@ -203,9 +206,6 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
speaker_to_name: dict[str, str] = {}
|
||||
|
||||
transcript_text = ""
|
||||
if call_title:
|
||||
transcript_text += f"Call Title: {call_title}\n\n"
|
||||
|
||||
call_purpose = call_metadata["purpose"]
|
||||
if call_purpose:
|
||||
transcript_text += f"Call Description: {call_purpose}\n\n"
|
||||
@@ -231,6 +231,11 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
transcript_text += f"{speaker_name}: {monolog}\n\n"
|
||||
|
||||
metadata = {}
|
||||
if call_metadata.get("system"):
|
||||
metadata["client"] = call_metadata.get("system")
|
||||
# TODO calls have a clientUniqueId field, can pull that in later
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=call_id,
|
||||
@@ -243,7 +248,7 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={},
|
||||
metadata={"client": call_metadata.get("system")},
|
||||
)
|
||||
)
|
||||
yield doc_batch
|
||||
@@ -263,6 +268,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
# if this env variable is set, don't start from a timestamp before the specified
|
||||
# start time
|
||||
# TODO: remove this once this is globally available
|
||||
@@ -272,6 +279,10 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
special_start_datetime = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
# don't let the special start dt be past the end time, this causes issues when
|
||||
# the Gong API (`filter.fromDateTime: must be before toDateTime`)
|
||||
special_start_datetime = min(special_start_datetime, end_datetime)
|
||||
|
||||
start_datetime = max(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc), special_start_datetime
|
||||
)
|
||||
@@ -280,11 +291,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# so adding a 1 day buffer and fetching by default till current time
|
||||
start_one_day_offset = start_datetime - timedelta(days=1)
|
||||
start_time = start_one_day_offset.isoformat()
|
||||
end_time = (
|
||||
datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
if self.use_end_time
|
||||
else None
|
||||
)
|
||||
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Fetching Gong calls between {start_time} and {end_time}")
|
||||
return self._fetch_calls(start_time, end_time)
|
||||
@@ -292,7 +300,6 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = GongConnector()
|
||||
connector.load_credentials(
|
||||
@@ -302,6 +309,5 @@ if __name__ == "__main__":
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
latest_docs = connector.load_from_state()
|
||||
print(next(latest_docs))
|
||||
|
||||
@@ -62,7 +62,10 @@ class GDriveMimeType(str, Enum):
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
add_retries = retry_builder()
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _run_drive_file_query(
|
||||
@@ -101,12 +104,18 @@ def _run_drive_file_query(
|
||||
for file in files:
|
||||
if follow_shortcuts and "shortcutDetails" in file:
|
||||
try:
|
||||
file = service.files().get(
|
||||
fileId=file["shortcutDetails"]["targetId"],
|
||||
supportsAllDrives=include_shared,
|
||||
fields="mimeType, id, name, modifiedTime, webViewLink, shortcutDetails",
|
||||
)
|
||||
file = add_retries(lambda: file.execute())()
|
||||
file_shortcut_points_to = add_retries(
|
||||
lambda: (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file["shortcutDetails"]["targetId"],
|
||||
supportsAllDrives=include_shared,
|
||||
fields="mimeType, id, name, modifiedTime, webViewLink, shortcutDetails",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
)()
|
||||
yield file_shortcut_points_to
|
||||
except HttpError:
|
||||
logger.error(
|
||||
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
|
||||
@@ -114,7 +123,8 @@ def _run_drive_file_query(
|
||||
if continue_on_failure:
|
||||
continue
|
||||
raise
|
||||
yield file
|
||||
else:
|
||||
yield file
|
||||
|
||||
|
||||
def _get_folder_id(
|
||||
@@ -456,24 +466,20 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
doc_batch = []
|
||||
for file in files_batch:
|
||||
try:
|
||||
text_contents = extract_text(file, service)
|
||||
if text_contents:
|
||||
full_context = file["name"] + " - " + text_contents
|
||||
else:
|
||||
full_context = file["name"]
|
||||
text_contents = extract_text(file, service) or ""
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=file["webViewLink"],
|
||||
sections=[
|
||||
Section(link=file["webViewLink"], text=full_context)
|
||||
Section(link=file["webViewLink"], text=text_contents)
|
||||
],
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: True},
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -25,9 +25,9 @@ from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.server.models import CredentialBase
|
||||
from danswer.server.models import GoogleAppCredentials
|
||||
from danswer.server.models import GoogleServiceAccountKey
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import GoogleAppCredentials
|
||||
from danswer.server.documents.models import GoogleServiceAccountKey
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -130,7 +130,7 @@ def build_service_account_creds(
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
is_admin=True,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
0
backend/danswer/connectors/google_site/__init__.py
Normal file
0
backend/danswer/connectors/google_site/__init__.py
Normal file
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
import urllib.parse
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -20,42 +20,31 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def process_link(element: BeautifulSoup | Tag) -> str:
|
||||
href = cast(str | None, element.get("href"))
|
||||
if not href:
|
||||
raise RuntimeError(f"Invalid link - {element}")
|
||||
def a_tag_text_to_path(atag: Tag) -> str:
|
||||
page_path = atag.text.strip().lower()
|
||||
page_path = re.sub(r"[^a-zA-Z0-9\s]", "", page_path)
|
||||
page_path = "-".join(page_path.split())
|
||||
|
||||
# cleanup href
|
||||
href = urllib.parse.unquote(href)
|
||||
href = href.rstrip(".html").lower()
|
||||
href = href.replace("_", "")
|
||||
href = href.replace(" ", "-")
|
||||
|
||||
return href
|
||||
return page_path
|
||||
|
||||
|
||||
def find_google_sites_page_path_from_navbar(
|
||||
element: BeautifulSoup | Tag, path: str, is_initial: bool
|
||||
element: BeautifulSoup | Tag, path: str, depth: int
|
||||
) -> str | None:
|
||||
ul = cast(Tag | None, element.find("ul"))
|
||||
if ul:
|
||||
if not is_initial:
|
||||
a = cast(Tag, element.find("a"))
|
||||
new_path = f"{path}/{process_link(a)}"
|
||||
if a.get("aria-selected") == "true":
|
||||
return new_path
|
||||
else:
|
||||
new_path = ""
|
||||
for li in ul.find_all("li", recursive=False):
|
||||
found_link = find_google_sites_page_path_from_navbar(li, new_path, False)
|
||||
if found_link:
|
||||
return found_link
|
||||
else:
|
||||
a = cast(Tag, element.find("a"))
|
||||
if a:
|
||||
href = process_link(a)
|
||||
if href and a.get("aria-selected") == "true":
|
||||
return path + "/" + href
|
||||
lis = cast(
|
||||
list[Tag],
|
||||
element.find_all("li", attrs={"data-nav-level": f"{depth}"}),
|
||||
)
|
||||
for li in lis:
|
||||
a = cast(Tag, li.find("a"))
|
||||
if a.get("aria-selected") == "true":
|
||||
return f"{path}/{a_tag_text_to_path(a)}"
|
||||
elif a.get("aria-expanded") == "true":
|
||||
sub_path = find_google_sites_page_path_from_navbar(
|
||||
element, f"{path}/{a_tag_text_to_path(a)}", depth + 1
|
||||
)
|
||||
if sub_path:
|
||||
return sub_path
|
||||
|
||||
return None
|
||||
|
||||
@@ -79,6 +68,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
|
||||
# load the HTML files
|
||||
files = load_files_from_zip(self.zip_path)
|
||||
count = 0
|
||||
for file_info, file_io in files:
|
||||
# skip non-published files
|
||||
if "/PUBLISHED/" not in file_info.filename:
|
||||
@@ -94,13 +84,15 @@ class GoogleSitesConnector(LoadConnector):
|
||||
# get the link out of the navbar
|
||||
header = cast(Tag, soup.find("header"))
|
||||
nav = cast(Tag, header.find("nav"))
|
||||
path = find_google_sites_page_path_from_navbar(nav, "", True)
|
||||
path = find_google_sites_page_path_from_navbar(nav, "", 1)
|
||||
if not path:
|
||||
count += 1
|
||||
logger.error(
|
||||
f"Could not find path for '{file_info.filename}'. "
|
||||
+ "This page will not have a working link."
|
||||
+ "This page will not have a working link.\n\n"
|
||||
+ f"# of broken links so far - {count}"
|
||||
)
|
||||
|
||||
logger.info(f"Path to page: {path}")
|
||||
# cleanup the hidden `Skip to main content` and `Skip to navigation` that
|
||||
# appears at the top of every page
|
||||
for div in soup.find_all("div", attrs={"data-is-touch-wrapper": "true"}):
|
||||
|
||||
@@ -8,6 +8,7 @@ import requests
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -76,14 +77,26 @@ class GuruConnector(LoadConnector, PollConnector):
|
||||
for card in cards:
|
||||
title = card["preferredPhrase"]
|
||||
link = GURU_CARDS_URL + card["slug"]
|
||||
content_text = title + "\n" + parse_html_page_basic(card["content"])
|
||||
content_text = parse_html_page_basic(card["content"])
|
||||
last_updated = time_str_to_utc(card["lastModified"])
|
||||
last_verified = (
|
||||
time_str_to_utc(card.get("lastVerified"))
|
||||
if card.get("lastVerified")
|
||||
else None
|
||||
)
|
||||
|
||||
# For Danswer, we decay document score overtime, either last_updated or
|
||||
# last_verified is a good enough signal for the document's recency
|
||||
latest_time = (
|
||||
max(last_verified, last_updated) if last_verified else last_updated
|
||||
)
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=card["id"],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.GURU,
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=latest_time,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
@@ -109,3 +122,18 @@ class GuruConnector(LoadConnector, PollConnector):
|
||||
end_time = unixtime_to_guru_time_str(end)
|
||||
|
||||
return self._process_cards(start_time, end_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = GuruConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"guru_user": os.environ["GURU_USER"],
|
||||
"guru_user_token": os.environ["GURU_USER_TOKEN"],
|
||||
}
|
||||
)
|
||||
|
||||
latest_docs = connector.load_from_state()
|
||||
print(next(latest_docs))
|
||||
|
||||
@@ -73,7 +73,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
title = ticket.properties["subject"]
|
||||
link = self.ticket_base_url + ticket.id
|
||||
content_text = title + "\n" + ticket.properties["content"]
|
||||
content_text = ticket.properties["content"]
|
||||
|
||||
associated_emails: list[str] = []
|
||||
associated_notes: list[str] = []
|
||||
|
||||
0
backend/danswer/connectors/linear/__init__.py
Normal file
0
backend/danswer/connectors/linear/__init__.py
Normal file
@@ -8,6 +8,7 @@ import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -30,7 +31,6 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response: requests.Response | None = None
|
||||
for i in range(_NUM_RETRIES):
|
||||
try:
|
||||
response = requests.post(
|
||||
@@ -187,8 +187,8 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
],
|
||||
source=DocumentSource.LINEAR,
|
||||
semantic_identifier=node["identifier"],
|
||||
doc_updated_at=time_str_to_utc(node["updatedAt"]),
|
||||
metadata={
|
||||
"updated_at": node["updatedAt"],
|
||||
"team": node["team"]["name"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import INDEX_SEPARATOR
|
||||
from danswer.utils.text_processing import make_url_compatible
|
||||
|
||||
|
||||
class InputType(str, Enum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
|
||||
|
||||
class ConnectorMissingCredentialError(PermissionError):
|
||||
@@ -14,44 +22,93 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Section:
|
||||
link: str
|
||||
class Section(BaseModel):
|
||||
text: str
|
||||
link: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
|
||||
class BasicExpertInfo(BaseModel):
|
||||
"""Basic Information for the owner of a document, any of the fields can be left as None
|
||||
Display fallback goes as follows:
|
||||
- first_name + (optional middle_initial) + last_name
|
||||
- display_name
|
||||
- email
|
||||
- first_name
|
||||
"""
|
||||
|
||||
display_name: str | None = None
|
||||
first_name: str | None = None
|
||||
middle_initial: str | None = None
|
||||
last_name: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class DocumentBase(BaseModel):
|
||||
"""Used for Danswer ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: list[Section]
|
||||
source: DocumentSource
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
metadata: dict[str, Any]
|
||||
metadata: dict[str, str | list[str]]
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
# Owner, creator, etc.
|
||||
primary_owners: list[str] | None = None
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
# Assignee, space owner, etc.
|
||||
secondary_owners: list[str] | None = None
|
||||
# `title` is used when computing best matches for a query
|
||||
# if `None`, then we will use the `semantic_identifier` as the title in Vespa
|
||||
secondary_owners: list[BasicExpertInfo] | None = None
|
||||
# title is used for search whereas semantic_identifier is used for displaying in the UI
|
||||
# different because Slack message may display as #general but general should not be part
|
||||
# of the search, at least not in the same way as a document title should be for like Confluence
|
||||
# The default title is semantic_identifier though unless otherwise specified
|
||||
title: str | None = None
|
||||
from_ingestion_api: bool = False
|
||||
|
||||
def get_title_for_document_index(self) -> str:
|
||||
def get_title_for_document_index(self) -> str | None:
|
||||
# If title is explicitly empty, return a None here for embedding purposes
|
||||
if self.title == "":
|
||||
return None
|
||||
return self.semantic_identifier if self.title is None else self.title
|
||||
|
||||
def get_metadata_str_attributes(self) -> list[str] | None:
|
||||
if not self.metadata:
|
||||
return None
|
||||
# Combined string for the key/value for easy filtering
|
||||
attributes: list[str] = []
|
||||
for k, v in self.metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
|
||||
source: DocumentSource
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a document"""
|
||||
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
|
||||
|
||||
|
||||
class InputType(str, Enum):
|
||||
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
|
||||
POLL = "poll" # e.g. calling an API to get all documents in the last hour
|
||||
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
|
||||
@classmethod
|
||||
def from_base(cls, base: DocumentBase) -> "Document":
|
||||
return cls(
|
||||
id=make_url_compatible(base.id)
|
||||
if base.id
|
||||
else "ingestion_api_" + make_url_compatible(base.semantic_identifier),
|
||||
sections=base.sections,
|
||||
source=base.source or DocumentSource.INGESTION_API,
|
||||
semantic_identifier=base.semantic_identifier,
|
||||
metadata=base.metadata,
|
||||
doc_updated_at=base.doc_updated_at,
|
||||
primary_owners=base.primary_owners,
|
||||
secondary_owners=base.secondary_owners,
|
||||
title=base.title,
|
||||
from_ingestion_api=base.from_ingestion_api,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexAttemptMetadata:
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
@@ -24,6 +24,8 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionPage:
|
||||
@@ -80,6 +82,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"Notion-Version": "2022-06-28",
|
||||
}
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
# if enabled, will recursively index child pages as they are found rather
|
||||
# relying entirely on the `search` API. We have recieved reports that the
|
||||
# `search` API misses many pages - in those cases, this might need to be
|
||||
@@ -87,8 +90,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
# NOTE: this also removes all benefits polling, since we need to traverse
|
||||
# all pages regardless of if they are updated. If the notion workspace is
|
||||
# very large, this may not be practical.
|
||||
self.recursive_index_enabled = recursive_index_enabled
|
||||
self.root_page_id = root_page_id
|
||||
self.recursive_index_enabled = (
|
||||
recursive_index_enabled or self.root_page_id is not None
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
|
||||
@@ -96,7 +100,12 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logger.debug(f"Fetching children of block with ID '{block_id}'")
|
||||
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
query_params = None if not cursor else {"start_cursor": cursor}
|
||||
res = requests.get(block_url, headers=self.headers, params=query_params)
|
||||
res = requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
params=query_params,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
@@ -109,7 +118,11 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Fetch a page from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching page for ID '{page_id}'")
|
||||
block_url = f"https://api.notion.com/v1/pages/{page_id}"
|
||||
res = requests.get(block_url, headers=self.headers)
|
||||
res = requests.get(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
@@ -117,6 +130,64 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
raise e
|
||||
return NotionPage(**res.json())
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_database(
|
||||
self, database_id: str, cursor: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a database from it's ID via the Notion API."""
|
||||
logger.debug(f"Fetching database for ID '{database_id}'")
|
||||
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
|
||||
body = None if not cursor else {"start_cursor": cursor}
|
||||
res = requests.post(
|
||||
block_url,
|
||||
headers=self.headers,
|
||||
json=body,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
if res.json().get("code") == "object_not_found":
|
||||
# this happens when a database is not shared with the integration
|
||||
# in this case, we should just ignore the database
|
||||
logger.error(
|
||||
f"Unable to access database with ID '{database_id}'. "
|
||||
f"This is likely due to the database not being shared "
|
||||
f"with the Danswer integration. Exact exception:\n{e}"
|
||||
)
|
||||
return {"results": [], "next_cursor": None}
|
||||
logger.exception(f"Error fetching database - {res.json()}")
|
||||
raise e
|
||||
return res.json()
|
||||
|
||||
def _read_pages_from_database(self, database_id: str) -> list[str]:
|
||||
"""Returns a list of all page IDs in the database"""
|
||||
result_pages: list[str] = []
|
||||
cursor = None
|
||||
while True:
|
||||
data = self._fetch_database(database_id, cursor)
|
||||
|
||||
for result in data["results"]:
|
||||
obj_id = result["id"]
|
||||
obj_type = result["object"]
|
||||
if obj_type == "page":
|
||||
logger.debug(
|
||||
f"Found page with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
result_pages.append(result["id"])
|
||||
elif obj_type == "database":
|
||||
logger.debug(
|
||||
f"Found database with ID '{obj_id}' in database '{database_id}'"
|
||||
)
|
||||
result_pages.extend(self._read_pages_from_database(obj_id))
|
||||
|
||||
if data["next_cursor"] is None:
|
||||
break
|
||||
|
||||
cursor = data["next_cursor"]
|
||||
|
||||
return result_pages
|
||||
|
||||
def _read_blocks(
|
||||
self, page_block_id: str
|
||||
) -> tuple[list[tuple[str, str]], list[str]]:
|
||||
@@ -141,8 +212,20 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
if result["has_children"] and result_type == "child_page":
|
||||
child_pages.append(result_block_id)
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logger.debug(f"Entering sub-block: {result_block_id}")
|
||||
subblock_result_lines, subblock_child_pages = self._read_blocks(
|
||||
result_block_id
|
||||
)
|
||||
logger.debug(f"Finished sub-block: {result_block_id}")
|
||||
result_lines.extend(subblock_result_lines)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
|
||||
if result_type == "child_database" and self.recursive_index_enabled:
|
||||
child_pages.extend(self._read_pages_from_database(result_block_id))
|
||||
|
||||
cur_result_text = "\n".join(cur_result_text_arr)
|
||||
if cur_result_text:
|
||||
@@ -184,7 +267,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
yield (
|
||||
Document(
|
||||
id=page.id,
|
||||
sections=[Section(link=page.url, text=f"{page_title}\n")]
|
||||
# Will add title to the first section later in processing
|
||||
sections=[Section(link=page.url, text="")]
|
||||
+ [
|
||||
Section(
|
||||
link=f"{page.url}#{block_id.replace('-', '')}",
|
||||
@@ -221,6 +305,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"https://api.notion.com/v1/search",
|
||||
headers=self.headers,
|
||||
json=query_dict,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
res.raise_for_status()
|
||||
return NotionSearchResponse(**res.json())
|
||||
|
||||
@@ -10,9 +10,11 @@ from retry import retry
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -93,26 +95,24 @@ class ProductboardConnector(PollConnector):
|
||||
for feature in self._fetch_documents(
|
||||
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/features"
|
||||
):
|
||||
owner = self._get_owner_email(feature)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
yield Document(
|
||||
id=feature["id"],
|
||||
sections=[
|
||||
Section(
|
||||
link=feature["links"]["html"],
|
||||
text=" - ".join(
|
||||
(
|
||||
feature["name"],
|
||||
self._parse_description_html(feature["description"]),
|
||||
)
|
||||
),
|
||||
text=self._parse_description_html(feature["description"]),
|
||||
)
|
||||
],
|
||||
semantic_identifier=feature["name"],
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(feature["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"productboard_entity_type": feature["type"],
|
||||
"entity_type": feature["type"],
|
||||
"status": feature["status"]["name"],
|
||||
"owner": self._get_owner_email(feature),
|
||||
"updated_at": feature["updatedAt"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -121,25 +121,23 @@ class ProductboardConnector(PollConnector):
|
||||
for component in self._fetch_documents(
|
||||
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/components"
|
||||
):
|
||||
owner = self._get_owner_email(component)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
yield Document(
|
||||
id=component["id"],
|
||||
sections=[
|
||||
Section(
|
||||
link=component["links"]["html"],
|
||||
text=" - ".join(
|
||||
(
|
||||
component["name"],
|
||||
self._parse_description_html(component["description"]),
|
||||
)
|
||||
),
|
||||
text=self._parse_description_html(component["description"]),
|
||||
)
|
||||
],
|
||||
semantic_identifier=component["name"],
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(component["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"productboard_entity_type": "component",
|
||||
"owner": self._get_owner_email(component),
|
||||
"updated_at": component["updatedAt"],
|
||||
"entity_type": "component",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -149,25 +147,23 @@ class ProductboardConnector(PollConnector):
|
||||
for product in self._fetch_documents(
|
||||
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/products"
|
||||
):
|
||||
owner = self._get_owner_email(product)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
yield Document(
|
||||
id=product["id"],
|
||||
sections=[
|
||||
Section(
|
||||
link=product["links"]["html"],
|
||||
text=" - ".join(
|
||||
(
|
||||
product["name"],
|
||||
self._parse_description_html(product["description"]),
|
||||
)
|
||||
),
|
||||
text=self._parse_description_html(product["description"]),
|
||||
)
|
||||
],
|
||||
semantic_identifier=product["name"],
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(product["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"productboard_entity_type": "product",
|
||||
"owner": self._get_owner_email(product),
|
||||
"updated_at": product["updatedAt"],
|
||||
"entity_type": "product",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -175,26 +171,24 @@ class ProductboardConnector(PollConnector):
|
||||
for objective in self._fetch_documents(
|
||||
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/objectives"
|
||||
):
|
||||
owner = self._get_owner_email(objective)
|
||||
experts = [BasicExpertInfo(email=owner)] if owner else None
|
||||
|
||||
yield Document(
|
||||
id=objective["id"],
|
||||
sections=[
|
||||
Section(
|
||||
link=objective["links"]["html"],
|
||||
text=" - ".join(
|
||||
(
|
||||
objective["name"],
|
||||
self._parse_description_html(objective["description"]),
|
||||
)
|
||||
),
|
||||
text=self._parse_description_html(objective["description"]),
|
||||
)
|
||||
],
|
||||
semantic_identifier=objective["name"],
|
||||
source=DocumentSource.PRODUCTBOARD,
|
||||
doc_updated_at=time_str_to_utc(objective["updatedAt"]),
|
||||
primary_owners=experts,
|
||||
metadata={
|
||||
"productboard_entity_type": "release",
|
||||
"entity_type": "release",
|
||||
"state": objective["state"],
|
||||
"owner": self._get_owner_email(objective),
|
||||
"updated_at": objective["updatedAt"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -252,3 +246,20 @@ class ProductboardConnector(PollConnector):
|
||||
|
||||
if document_batch:
|
||||
yield document_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = ProductboardConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"productboard_access_token": os.environ["PRODUCTBOARD_ACCESS_TOKEN"],
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_year_ago = current - 24 * 60 * 60 * 360
|
||||
latest_docs = connector.poll_source(one_year_ago, current)
|
||||
print(next(latest_docs))
|
||||
|
||||
1
backend/danswer/connectors/requesttracker/.gitignore
vendored
Normal file
1
backend/danswer/connectors/requesttracker/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.env
|
||||
153
backend/danswer/connectors/requesttracker/connector.py
Normal file
153
backend/danswer/connectors/requesttracker/connector.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import DEBUG as LOG_LVL_DEBUG
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from rt.rest1 import ALL_QUEUES
|
||||
from rt.rest1 import Rt
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class RequestTrackerError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestTrackerConnector(PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
def txn_link(self, tid: int, txn: int) -> str:
|
||||
return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}"
|
||||
|
||||
def build_doc_sections_from_txn(
|
||||
self, connection: Rt, ticket_id: int
|
||||
) -> List[Section]:
|
||||
Sections: List[Section] = []
|
||||
|
||||
get_history_resp = connection.get_history(ticket_id)
|
||||
|
||||
if get_history_resp is None:
|
||||
raise RequestTrackerError(f"Ticket {ticket_id} cannot be found")
|
||||
|
||||
for tx in get_history_resp:
|
||||
Sections.append(
|
||||
Section(
|
||||
link=self.txn_link(ticket_id, int(tx["id"])),
|
||||
text="\n".join(
|
||||
[
|
||||
f"{k}:\n{v}\n" if k != "Attachments" else ""
|
||||
for (k, v) in tx.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
return Sections
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
self.rt_username = credentials.get("requesttracker_username")
|
||||
self.rt_password = credentials.get("requesttracker_password")
|
||||
self.rt_base_url = credentials.get("requesttracker_base_url")
|
||||
return None
|
||||
|
||||
# This does not include RT file attachments yet.
|
||||
def _process_tickets(
|
||||
self, start: datetime, end: datetime
|
||||
) -> GenerateDocumentsOutput:
|
||||
if any([self.rt_username, self.rt_password, self.rt_base_url]) is None:
|
||||
raise ConnectorMissingCredentialError("requesttracker")
|
||||
|
||||
Rt0 = Rt(
|
||||
f"{self.rt_base_url}/REST/1.0/",
|
||||
self.rt_username,
|
||||
self.rt_password,
|
||||
)
|
||||
|
||||
Rt0.login()
|
||||
|
||||
d0 = start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
d1 = end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
tickets = Rt0.search(
|
||||
Queue=ALL_QUEUES,
|
||||
raw_query=f"Updated > '{d0}' AND Updated < '{d1}'",
|
||||
)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
|
||||
for ticket in tickets:
|
||||
ticket_keys_to_omit = ["id", "Subject"]
|
||||
tid: int = int(ticket["numerical_id"])
|
||||
ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}"
|
||||
logger.info(f"Processing ticket {tid}")
|
||||
doc = Document(
|
||||
id=ticket["id"],
|
||||
# Will add title to the first section later in processing
|
||||
sections=[Section(link=ticketLink, text="")]
|
||||
+ self.build_doc_sections_from_txn(Rt0, tid),
|
||||
source=DocumentSource.REQUESTTRACKER,
|
||||
semantic_identifier=ticket["Subject"],
|
||||
metadata={
|
||||
key: value
|
||||
for key, value in ticket.items()
|
||||
if key not in ticket_keys_to_omit
|
||||
},
|
||||
)
|
||||
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Keep query short, only look behind 1 day at maximum
|
||||
one_day_ago: float = end - (24 * 60 * 60)
|
||||
_start: float = start if start > one_day_ago else one_day_ago
|
||||
start_datetime = datetime.fromtimestamp(_start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
yield from self._process_tickets(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logger.setLevel(LOG_LVL_DEBUG)
|
||||
rt_connector = RequestTrackerConnector()
|
||||
rt_connector.load_credentials(
|
||||
{
|
||||
"requesttracker_username": os.getenv("RT_USERNAME"),
|
||||
"requesttracker_password": os.getenv("RT_PASSWORD"),
|
||||
"requesttracker_base_url": os.getenv("RT_BASE_URL"),
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - (24 * 60 * 60) # 1 days
|
||||
latest_docs = rt_connector.poll_source(one_day_ago, current)
|
||||
|
||||
for doc in latest_docs:
|
||||
print(doc)
|
||||
@@ -5,14 +5,35 @@ from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
|
||||
|
||||
def _article_to_document(article: Article) -> Document:
|
||||
author = BasicExpertInfo(
|
||||
display_name=article.author.name, email=article.author.email
|
||||
)
|
||||
update_time = time_str_to_utc(article.updated_at)
|
||||
return Document(
|
||||
id=f"article:{article.id}",
|
||||
sections=[
|
||||
Section(link=article.html_url, text=parse_html_page_basic(article.body))
|
||||
],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=article.title,
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[author],
|
||||
metadata={"type": "article"},
|
||||
)
|
||||
|
||||
|
||||
class ZendeskClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Zendesk Client is not set up, was load_credentials called?")
|
||||
@@ -34,18 +55,6 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def _article_to_document(self, article: Article) -> Document:
|
||||
return Document(
|
||||
id=f"article:{article.id}",
|
||||
sections=[Section(link=article.html_url, text=article.body)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier="Article: " + article.title,
|
||||
metadata={
|
||||
"type": "article",
|
||||
"updated_at": article.updated_at,
|
||||
},
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -64,7 +73,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
if article.body is None:
|
||||
continue
|
||||
|
||||
doc_batch.append(self._article_to_document(article))
|
||||
doc_batch.append(_article_to_document(article))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
0
backend/danswer/connectors/zulip/__init__.py
Normal file
0
backend/danswer/connectors/zulip/__init__.py
Normal file
@@ -1,33 +1,37 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
import timeago # type: ignore
|
||||
from slack_sdk.models.blocks import ActionsBlock
|
||||
from slack_sdk.models.blocks import Block
|
||||
from slack_sdk.models.blocks import ButtonElement
|
||||
from slack_sdk.models.blocks import ConfirmObject
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import HeaderBlock
|
||||
from slack_sdk.models.blocks import Option
|
||||
from slack_sdk.models.blocks import RadioButtonsElement
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from danswer.configs.danswerbot_configs import ENABLE_SLACK_DOC_FEEDBACK
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.utils import build_feedback_block_id
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
|
||||
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
|
||||
_MAX_BLURB_LEN = 75
|
||||
|
||||
|
||||
def build_qa_feedback_block(query_event_id: int) -> Block:
|
||||
def build_qa_feedback_block(message_id: int) -> Block:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_block_id(query_event_id),
|
||||
block_id=build_feedback_id(message_id),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=LIKE_BLOCK_ACTION_ID,
|
||||
@@ -43,33 +47,44 @@ def build_qa_feedback_block(query_event_id: int) -> Block:
|
||||
)
|
||||
|
||||
|
||||
def get_document_feedback_blocks() -> Block:
|
||||
return SectionBlock(
|
||||
text=(
|
||||
"- 'Up-Boost' if this document is a good source of information and should be "
|
||||
"shown more often.\n"
|
||||
"- 'Down-boost' if this document is a poor source of information and should be "
|
||||
"shown less often.\n"
|
||||
"- 'Hide' if this document is deprecated and should never be shown anymore."
|
||||
),
|
||||
accessory=RadioButtonsElement(
|
||||
options=[
|
||||
Option(
|
||||
text=":thumbsup: Up-Boost",
|
||||
value=SearchFeedbackType.ENDORSE.value,
|
||||
),
|
||||
Option(
|
||||
text=":thumbsdown: Down-Boost",
|
||||
value=SearchFeedbackType.REJECT.value,
|
||||
),
|
||||
Option(
|
||||
text=":x: Hide",
|
||||
value=SearchFeedbackType.HIDE.value,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_doc_feedback_block(
|
||||
query_event_id: int,
|
||||
message_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
) -> Block:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_block_id(query_event_id, document_id, document_rank),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=SearchFeedbackType.ENDORSE.value,
|
||||
text="⬆",
|
||||
style="primary",
|
||||
confirm=ConfirmObject(
|
||||
title="Endorse this Document",
|
||||
text="This is a good source of information and should be shown more often!",
|
||||
),
|
||||
),
|
||||
ButtonElement(
|
||||
action_id=SearchFeedbackType.REJECT.value,
|
||||
text="⬇",
|
||||
style="danger",
|
||||
confirm=ConfirmObject(
|
||||
title="Reject this Document",
|
||||
text="This is a bad source of information and should be shown less often.",
|
||||
),
|
||||
),
|
||||
],
|
||||
) -> ButtonElement:
|
||||
feedback_id = build_feedback_id(message_id, document_id, document_rank)
|
||||
return ButtonElement(
|
||||
action_id=FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID,
|
||||
value=feedback_id,
|
||||
text="Give Feedback",
|
||||
)
|
||||
|
||||
|
||||
@@ -77,7 +92,7 @@ def get_restate_blocks(
|
||||
msg: str,
|
||||
is_bot_msg: bool,
|
||||
) -> list[Block]:
|
||||
# Only the slash command needs this context because the user doesnt see their own input
|
||||
# Only the slash command needs this context because the user doesn't see their own input
|
||||
if not is_bot_msg:
|
||||
return []
|
||||
|
||||
@@ -88,13 +103,15 @@ def get_restate_blocks(
|
||||
|
||||
|
||||
def build_documents_blocks(
|
||||
documents: list[SearchDoc],
|
||||
query_event_id: int,
|
||||
documents: list[SavedSearchDoc],
|
||||
message_id: int | None,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
include_feedback: bool = ENABLE_SLACK_DOC_FEEDBACK,
|
||||
) -> list[Block]:
|
||||
header_text = (
|
||||
"Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
|
||||
)
|
||||
seen_docs_identifiers = set()
|
||||
section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
|
||||
section_blocks: list[Block] = [HeaderBlock(text=header_text)]
|
||||
included_docs = 0
|
||||
for rank, d in enumerate(documents):
|
||||
if d.document_id in seen_docs_identifiers:
|
||||
@@ -110,24 +127,32 @@ def build_documents_blocks(
|
||||
|
||||
included_docs += 1
|
||||
|
||||
header_line = f"{doc_sem_id}\n"
|
||||
if d.link:
|
||||
block_text = f"<{d.link}|{doc_sem_id}>:\n>{remove_slack_text_interactions(match_str)}"
|
||||
else:
|
||||
block_text = f"{doc_sem_id}:\n>{remove_slack_text_interactions(match_str)}"
|
||||
header_line = f"<{d.link}|{doc_sem_id}>\n"
|
||||
|
||||
updated_at_line = ""
|
||||
if d.updated_at is not None:
|
||||
updated_at_line = (
|
||||
f"_Updated {timeago.format(d.updated_at, datetime.now(pytz.utc))}_\n"
|
||||
)
|
||||
|
||||
body_text = f">{remove_slack_text_interactions(match_str)}"
|
||||
|
||||
block_text = header_line + updated_at_line + body_text
|
||||
|
||||
feedback: ButtonElement | dict = {}
|
||||
if message_id is not None:
|
||||
feedback = build_doc_feedback_block(
|
||||
message_id=message_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
)
|
||||
|
||||
section_blocks.append(
|
||||
SectionBlock(text=block_text),
|
||||
SectionBlock(text=block_text, accessory=feedback),
|
||||
)
|
||||
|
||||
if include_feedback:
|
||||
section_blocks.append(
|
||||
build_doc_feedback_block(
|
||||
query_event_id=query_event_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
),
|
||||
)
|
||||
|
||||
section_blocks.append(DividerBlock())
|
||||
|
||||
if included_docs >= num_docs_to_display:
|
||||
@@ -179,18 +204,29 @@ def build_quotes_block(
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
query_event_id: int,
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
quotes: list[DanswerQuote] | None,
|
||||
source_filters: list[DocumentSource] | None,
|
||||
time_cutoff: datetime | None,
|
||||
favor_recent: bool,
|
||||
skip_quotes: bool = False,
|
||||
) -> list[Block]:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
ai_answer_header = HeaderBlock(text="AI Answer")
|
||||
|
||||
filter_block: Block | None = None
|
||||
if time_cutoff or favor_recent:
|
||||
if time_cutoff or favor_recent or source_filters:
|
||||
filter_text = "Filters: "
|
||||
if source_filters:
|
||||
sources_str = ", ".join([s.value for s in source_filters])
|
||||
filter_text += f"`Sources in [{sources_str}]`"
|
||||
if time_cutoff or favor_recent:
|
||||
filter_text += " and "
|
||||
if time_cutoff is not None:
|
||||
time_str = time_cutoff.strftime("%b %d, %Y")
|
||||
filter_text += f"`Docs Updated >= {time_str}` "
|
||||
@@ -206,7 +242,8 @@ def build_qa_response_blocks(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
else:
|
||||
answer_block = SectionBlock(text=remove_slack_text_interactions(answer))
|
||||
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
|
||||
answer_block = SectionBlock(text=answer_processed)
|
||||
if quotes:
|
||||
quotes_blocks = build_quotes_block(quotes)
|
||||
|
||||
@@ -218,15 +255,22 @@ def build_qa_response_blocks(
|
||||
)
|
||||
]
|
||||
|
||||
feedback_block = build_qa_feedback_block(query_event_id=query_event_id)
|
||||
feedback_block = None
|
||||
if message_id is not None:
|
||||
feedback_block = build_qa_feedback_block(message_id=message_id)
|
||||
|
||||
response_blocks: list[Block] = [ai_answer_header]
|
||||
|
||||
if filter_block is not None:
|
||||
response_blocks.append(filter_block)
|
||||
|
||||
response_blocks.extend(
|
||||
[answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
|
||||
)
|
||||
response_blocks.append(answer_block)
|
||||
|
||||
if feedback_block is not None:
|
||||
response_blocks.append(feedback_block)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
response_blocks.append(DividerBlock())
|
||||
|
||||
return response_blocks
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
SLACK_CHANNEL_ID = "channel_id"
|
||||
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
|
||||
|
||||
@@ -1,19 +1,60 @@
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.views import View
|
||||
from slack_sdk.socket_mode import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.utils import decompose_block_id
|
||||
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import decompose_feedback_id
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.feedback import create_chat_message_feedback
|
||||
from danswer.db.feedback import create_doc_retrieval_feedback
|
||||
from danswer.db.feedback import update_query_event_feedback
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger_base = setup_logger()
|
||||
|
||||
|
||||
def handle_doc_feedback_button(
|
||||
req: SocketModeRequest,
|
||||
client: SocketModeClient,
|
||||
) -> None:
|
||||
if not (actions := req.payload.get("actions")):
|
||||
logger_base.error("Missing actions. Unable to build the source feedback view")
|
||||
return
|
||||
|
||||
# Extracts the feedback_id coming from the 'source feedback' button
|
||||
# and generates a new one for the View, to keep track of the doc info
|
||||
query_event_id, doc_id, doc_rank = decompose_feedback_id(actions[0].get("value"))
|
||||
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
|
||||
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
|
||||
data = View(
|
||||
type="modal",
|
||||
callback_id=VIEW_DOC_FEEDBACK_ID,
|
||||
external_id=external_id,
|
||||
# We use the private metadata to keep track of the channel id and thread ts
|
||||
private_metadata=f"{channel_id}_{thread_ts}",
|
||||
title="Give Feedback",
|
||||
blocks=[get_document_feedback_blocks()],
|
||||
submit="send",
|
||||
close="cancel",
|
||||
)
|
||||
|
||||
client.web_client.views_open(
|
||||
trigger_id=req.payload["trigger_id"], view=data.to_dict()
|
||||
)
|
||||
|
||||
|
||||
def handle_slack_feedback(
|
||||
block_id: str,
|
||||
feedback_id: str,
|
||||
feedback_type: str,
|
||||
client: WebClient,
|
||||
user_id_to_post_confirmation: str,
|
||||
@@ -22,37 +63,43 @@ def handle_slack_feedback(
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
query_id, doc_id, doc_rank = decompose_block_id(block_id)
|
||||
message_id, doc_id, doc_rank = decompose_feedback_id(feedback_id)
|
||||
|
||||
with Session(engine) as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
update_query_event_feedback(
|
||||
feedback=QAFeedbackType.LIKE
|
||||
if feedback_type == LIKE_BLOCK_ACTION_ID
|
||||
else QAFeedbackType.DISLIKE,
|
||||
query_id=query_id,
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
feedback_text="",
|
||||
chat_message_id=message_id,
|
||||
user_id=None, # no "user" for Slack bot for now
|
||||
db_session=db_session,
|
||||
)
|
||||
if feedback_type in [
|
||||
elif feedback_type in [
|
||||
SearchFeedbackType.ENDORSE.value,
|
||||
SearchFeedbackType.REJECT.value,
|
||||
SearchFeedbackType.HIDE.value,
|
||||
]:
|
||||
if doc_id is None or doc_rank is None:
|
||||
raise ValueError("Missing information for Document Feedback")
|
||||
|
||||
if feedback_type == SearchFeedbackType.ENDORSE.value:
|
||||
feedback = SearchFeedbackType.ENDORSE
|
||||
elif feedback_type == SearchFeedbackType.REJECT.value:
|
||||
feedback = SearchFeedbackType.REJECT
|
||||
else:
|
||||
feedback = SearchFeedbackType.HIDE
|
||||
|
||||
create_doc_retrieval_feedback(
|
||||
qa_event_id=query_id,
|
||||
message_id=message_id,
|
||||
document_id=doc_id,
|
||||
document_rank=doc_rank,
|
||||
user_id=None,
|
||||
document_index=get_default_document_index(),
|
||||
db_session=db_session,
|
||||
clicked=False, # Not tracking this for Slack
|
||||
feedback=SearchFeedbackType.ENDORSE
|
||||
if feedback_type == SearchFeedbackType.ENDORSE.value
|
||||
else SearchFeedbackType.REJECT,
|
||||
feedback=feedback,
|
||||
)
|
||||
else:
|
||||
logger_base.error(f"Feedback type '{feedback_type}' not supported")
|
||||
|
||||
# post message to slack confirming that feedback was received
|
||||
client.chat_postEphemeral(
|
||||
|
||||
@@ -6,8 +6,8 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
@@ -25,11 +25,15 @@ from danswer.danswerbot.slack.utils import fetch_userids_from_emails
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
|
||||
logger_base = setup_logger()
|
||||
|
||||
@@ -75,6 +79,7 @@ def handle_message(
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -83,36 +88,56 @@ def handle_message(
|
||||
Query thrown out by filters due to config does not count as a failure that should be notified
|
||||
Danswer failing to answer/retrieve docs does count and should be notified
|
||||
"""
|
||||
msg = message_info.msg_content
|
||||
channel = message_info.channel_to_respond
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
sender_id = message_info.sender
|
||||
bipass_filters = message_info.bipass_filters
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
|
||||
logger = cast(
|
||||
logging.Logger,
|
||||
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
|
||||
)
|
||||
|
||||
messages = message_info.thread_messages
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
sender_id = message_info.sender
|
||||
bypass_filters = message_info.bypass_filters
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_bot_dm = message_info.is_bot_dm
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
if channel_config and channel_config.persona:
|
||||
persona = channel_config.persona if channel_config else None
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in channel_config.persona.document_sets
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# List of user id to send message to, if None, send to everyone in channel
|
||||
send_to: list[str] | None = None
|
||||
respond_tag_only = False
|
||||
respond_team_member_list = None
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
channel_config
|
||||
and channel_config.persona
|
||||
and channel_config.persona.document_sets
|
||||
):
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
if channel_config and channel_config.channel_config:
|
||||
channel_conf = channel_config.channel_config
|
||||
if not bipass_filters and "answer_filters" in channel_conf:
|
||||
if not bypass_filters and "answer_filters" in channel_conf:
|
||||
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
|
||||
|
||||
if (
|
||||
"questionmark_prefilter" in channel_conf["answer_filters"]
|
||||
and "?" not in msg
|
||||
and "?" not in messages[-1].message
|
||||
):
|
||||
logger.info(
|
||||
"Skipping message since it does not contain a question mark"
|
||||
@@ -128,7 +153,7 @@ def handle_message(
|
||||
respond_tag_only = channel_conf.get("respond_tag_only") or False
|
||||
respond_team_member_list = channel_conf.get("respond_team_member_list") or None
|
||||
|
||||
if respond_tag_only and not bipass_filters:
|
||||
if respond_tag_only and not bypass_filters:
|
||||
logger.info(
|
||||
"Skipping message since the channel is configured such that "
|
||||
"DanswerBot only responds to tags"
|
||||
@@ -161,24 +186,34 @@ def handle_message(
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
def _get_answer(question: QuestionRequest) -> QAResponse:
|
||||
engine = get_sqlalchemy_engine()
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
action = "slack_slash_message"
|
||||
elif bypass_filters:
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
optional_telemetry(
|
||||
record_type=RecordType.USAGE,
|
||||
data={"action": action},
|
||||
)
|
||||
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
# This also handles creating the query event in postgres
|
||||
answer = answer_qa_query(
|
||||
question=question,
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
real_time_flow=False,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
answer_failed = False
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
# it allows the slack flow to extract out filters from the user query
|
||||
@@ -188,19 +223,30 @@ def handle_message(
|
||||
time_cutoff=None,
|
||||
)
|
||||
|
||||
auto_detect_filters = (
|
||||
persona.llm_filter_extraction if persona is not None else False
|
||||
)
|
||||
if disable_auto_detect_filters:
|
||||
auto_detect_filters = False
|
||||
|
||||
retrieval_details = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
QuestionRequest(
|
||||
query=msg,
|
||||
collection=DOCUMENT_INDEX_NAME,
|
||||
enable_auto_detect_filters=not disable_auto_detect_filters,
|
||||
filters=filters,
|
||||
favor_recent=None,
|
||||
offset=None,
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
answer_failed = True
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
f"in {num_retries} attempts"
|
||||
@@ -216,15 +262,21 @@ def handle_message(
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
try:
|
||||
remove_react(message_info, client)
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
try:
|
||||
remove_react(message_info, client)
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
|
||||
if answer_failed:
|
||||
return True
|
||||
|
||||
if answer.eval_res_valid is False:
|
||||
if answer.answer_valid is False:
|
||||
logger.info(
|
||||
"Answer was evaluated to be invalid, throwing it away without responding."
|
||||
)
|
||||
@@ -232,10 +284,18 @@ def handle_message(
|
||||
logger.debug(answer.answer)
|
||||
return True
|
||||
|
||||
if not answer.top_ranked_docs:
|
||||
logger.error(f"Unable to answer question: '{msg}' - no documents found")
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
@@ -254,18 +314,32 @@ def handle_message(
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(msg, is_bot_msg)
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
query_event_id=answer.query_event_id,
|
||||
message_id=answer.chat_message_id,
|
||||
answer=answer.answer,
|
||||
quotes=answer.quotes,
|
||||
time_cutoff=answer.time_cutoff,
|
||||
favor_recent=answer.favor_recent,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
skip_quotes=persona is not None, # currently Personas don't support quotes
|
||||
)
|
||||
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
document_blocks = (
|
||||
build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
if priority_ordered_docs
|
||||
else []
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -9,44 +9,39 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from danswer.danswerbot.slack.handlers.handle_feedback import handle_doc_feedback_button
|
||||
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
|
||||
from danswer.danswerbot.slack.handlers.handle_message import handle_message
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import decompose_block_id
|
||||
from danswer.danswerbot.slack.utils import decompose_feedback_id
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
|
||||
from danswer.danswerbot.slack.utils import get_view_values
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.search_nlp_models import warm_up_models
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class MissingTokensException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_socket_client() -> SocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.danswer.dev/slack_bot_setup
|
||||
try:
|
||||
slack_bot_tokens = fetch_tokens()
|
||||
except ConfigNotFoundError:
|
||||
raise MissingTokensException("Slack tokens not found")
|
||||
return SocketModeClient(
|
||||
# This app-level token will be used only for establishing a connection
|
||||
app_token=slack_bot_tokens.app_token,
|
||||
web_client=WebClient(token=slack_bot_tokens.bot_token),
|
||||
)
|
||||
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
if req.type == "events_api":
|
||||
@@ -77,7 +72,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
return False
|
||||
|
||||
if event_type == "message":
|
||||
bot_tag_id = client.web_client.auth_test().get("user_id")
|
||||
bot_tag_id = get_danswer_bot_app_id(client.web_client)
|
||||
# DMs with the bot don't pick up the @DanswerBot so we have to keep the
|
||||
# caught events_api
|
||||
if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
|
||||
@@ -101,8 +96,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
# Pick the root of the thread (if a thread exists)
|
||||
if thread_ts and message_ts != thread_ts:
|
||||
channel_specific_logger.info(
|
||||
# Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
|
||||
if (
|
||||
thread_ts
|
||||
and message_ts != thread_ts
|
||||
and event_type != "app_mention"
|
||||
and event.get("channel_type") != "im"
|
||||
):
|
||||
channel_specific_logger.debug(
|
||||
"Skipping message since it is not the root of a thread"
|
||||
)
|
||||
return False
|
||||
@@ -135,28 +136,41 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
|
||||
|
||||
|
||||
def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
actions = req.payload.get("actions")
|
||||
if not actions:
|
||||
logger.error("Unable to process block actions - no actions found")
|
||||
# Answer feedback
|
||||
if actions := req.payload.get("actions"):
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
feedback_type = cast(str, action.get("action_id"))
|
||||
feedback_id = cast(str, action.get("block_id"))
|
||||
channel_id = cast(str, req.payload["container"]["channel_id"])
|
||||
thread_ts = cast(str, req.payload["container"]["thread_ts"])
|
||||
# Doc feedback
|
||||
elif view := req.payload.get("view"):
|
||||
view_values = get_view_values(view["state"]["values"])
|
||||
private_metadata = view.get("private_metadata").split("_")
|
||||
if not view_values:
|
||||
logger.error("Unable to process feedback. Missing view values")
|
||||
return
|
||||
|
||||
feedback_type = [x for x in view_values.values()][0]
|
||||
feedback_id = cast(str, view.get("external_id"))
|
||||
channel_id = private_metadata[0]
|
||||
thread_ts = private_metadata[1]
|
||||
else:
|
||||
logger.error("Unable to process feedback. Actions or View not found")
|
||||
return
|
||||
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
action_id = cast(str, action.get("action_id"))
|
||||
block_id = cast(str, action.get("block_id"))
|
||||
user_id = cast(str, req.payload["user"]["id"])
|
||||
channel_id = cast(str, req.payload["container"]["channel_id"])
|
||||
thread_ts = cast(str, req.payload["container"]["thread_ts"])
|
||||
|
||||
handle_slack_feedback(
|
||||
block_id=block_id,
|
||||
feedback_type=action_id,
|
||||
feedback_id=feedback_id,
|
||||
feedback_type=feedback_type,
|
||||
client=client.web_client,
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_block_id(block_id)
|
||||
query_event_id, _, _ = decompose_feedback_id(feedback_id)
|
||||
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
|
||||
|
||||
|
||||
@@ -170,21 +184,29 @@ def build_request_details(
|
||||
tagged = event.get("type") == "app_mention"
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
bot_tag_id = client.web_client.auth_test().get("user_id")
|
||||
# Might exist even if not tagged, specifically in the case of @DanswerBot
|
||||
# in DanswerBot DM channel
|
||||
msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg)
|
||||
|
||||
msg = remove_danswer_bot_tag(msg, client=client.web_client)
|
||||
|
||||
if tagged:
|
||||
logger.info("User tagged DanswerBot")
|
||||
|
||||
if thread_ts != message_ts and thread_ts is not None:
|
||||
thread_messages = read_slack_thread(
|
||||
channel=channel, thread=thread_ts, client=client.web_client
|
||||
)
|
||||
else:
|
||||
thread_messages = [
|
||||
ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
]
|
||||
|
||||
return SlackMessageInfo(
|
||||
msg_content=msg,
|
||||
thread_messages=thread_messages,
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, thread_ts or message_ts),
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
sender=event.get("user") or None,
|
||||
bipass_filters=tagged,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
)
|
||||
|
||||
elif req.type == "slash_commands":
|
||||
@@ -192,13 +214,16 @@ def build_request_details(
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
|
||||
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
|
||||
|
||||
return SlackMessageInfo(
|
||||
msg_content=msg,
|
||||
thread_messages=[single_msg],
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=None,
|
||||
sender=sender,
|
||||
bipass_filters=True,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
)
|
||||
|
||||
raise RuntimeError("Programming fault, this should never happen.")
|
||||
@@ -247,8 +272,9 @@ def process_message(
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If @DanswerBot or /DanswerBot, always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bipass_filters)
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
|
||||
@@ -268,21 +294,59 @@ def acknowledge_message(req: SocketModeRequest, client: SocketModeClient) -> Non
|
||||
client.send_socket_mode_response(response)
|
||||
|
||||
|
||||
def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
if actions := req.payload.get("actions"):
|
||||
action = cast(dict[str, Any], actions[0])
|
||||
|
||||
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
|
||||
# AI Answer feedback
|
||||
return process_feedback(req, client)
|
||||
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
|
||||
# Activation of the "source feedback" button
|
||||
return handle_doc_feedback_button(req, client)
|
||||
|
||||
|
||||
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
if view := req.payload.get("view"):
|
||||
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
|
||||
return process_feedback(req, client)
|
||||
|
||||
|
||||
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
# Always respond right away, if Slack doesn't receive these frequently enough
|
||||
# it will assume the Bot is DEAD!!! :(
|
||||
acknowledge_message(req, client)
|
||||
|
||||
try:
|
||||
if req.type == "interactive" and req.payload.get("type") == "block_actions":
|
||||
return process_feedback(req, client)
|
||||
|
||||
if req.type == "interactive":
|
||||
if req.payload.get("type") == "block_actions":
|
||||
return action_routing(req, client)
|
||||
elif req.payload.get("type") == "view_submission":
|
||||
return view_routing(req, client)
|
||||
elif req.type == "events_api" or req.type == "slash_commands":
|
||||
return process_message(req, client)
|
||||
except Exception:
|
||||
logger.exception("Failed to process slack event")
|
||||
|
||||
|
||||
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.danswer.dev/slack_bot_setup
|
||||
return SocketModeClient(
|
||||
# This app-level token will be used only for establishing a connection
|
||||
app_token=slack_bot_tokens.app_token,
|
||||
web_client=WebClient(token=slack_bot_tokens.bot_token),
|
||||
)
|
||||
|
||||
|
||||
def _initialize_socket_client(socket_client: SocketModeClient) -> None:
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info("Listening for messages from Slack...")
|
||||
socket_client.connect()
|
||||
|
||||
|
||||
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
|
||||
# the slack bot in your workspace, and then add the bot to any channels you want to
|
||||
# try and answer questions for. Running this file will setup Danswer to listen to all
|
||||
@@ -293,21 +357,37 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
|
||||
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
|
||||
# without issue.
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
socket_client = _get_socket_client()
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW)
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info("Listening for messages from Slack...")
|
||||
socket_client.connect()
|
||||
slack_bot_tokens: SlackBotTokens | None = None
|
||||
socket_client: SocketModeClient | None = None
|
||||
while True:
|
||||
try:
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
|
||||
# Just not to stop this process
|
||||
from threading import Event
|
||||
if latest_slack_bot_tokens != slack_bot_tokens:
|
||||
if slack_bot_tokens is not None:
|
||||
logger.info("Slack Bot tokens have changed - reconnecting")
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
# to avoid + (1) if the user is changing tokens, they are likely okay with some
|
||||
# "migration downtime" and (2) if a single message is lost it is okay
|
||||
# as this should be a very rare occurrence
|
||||
if socket_client:
|
||||
socket_client.close()
|
||||
|
||||
Event().wait()
|
||||
except MissingTokensException:
|
||||
# try again every 30 seconds. This is needed since the user may add tokens
|
||||
# via the UI at any point in the programs lifecycle - if we just allow it to
|
||||
# fail, then the user will need to restart the containers after adding tokens
|
||||
logger.debug("Missing Slack Bot tokens - waiting 60 seconds and trying again")
|
||||
time.sleep(60)
|
||||
socket_client = _get_socket_client(slack_bot_tokens)
|
||||
_initialize_socket_client(socket_client)
|
||||
|
||||
# Let the handlers run in the background + re-check for token updates every 60 seconds
|
||||
Event().wait(timeout=60)
|
||||
except ConfigNotFoundError:
|
||||
# try again every 30 seconds. This is needed since the user may add tokens
|
||||
# via the UI at any point in the programs lifecycle - if we just allow it to
|
||||
# fail, then the user will need to restart the containers after adding tokens
|
||||
logger.debug(
|
||||
"Missing Slack Bot tokens - waiting 60 seconds and trying again"
|
||||
)
|
||||
if socket_client:
|
||||
socket_client.disconnect()
|
||||
time.sleep(60)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
msg_content: str
|
||||
thread_messages: list[ThreadMessage]
|
||||
channel_to_respond: str
|
||||
msg_to_respond: str | None
|
||||
sender: str | None
|
||||
bipass_filters: bool
|
||||
is_bot_msg: bool
|
||||
bypass_filters: bool # User has tagged @DanswerBot
|
||||
is_bot_msg: bool # User is using /DanswerBot
|
||||
is_bot_dm: bool # User is direct messaging to DanswerBot
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import cast
|
||||
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.server.models import SlackBotTokens
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
|
||||
|
||||
_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
|
||||
@@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block
|
||||
from slack_sdk.models.metadata import Metadata
|
||||
|
||||
from danswer.configs.constants import ID_SEPARATOR
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
from danswer.danswerbot.slack.tokens import fetch_tokens
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DANSWER_BOT_APP_ID: str | None = None
|
||||
|
||||
|
||||
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
|
||||
global DANSWER_BOT_APP_ID
|
||||
if DANSWER_BOT_APP_ID is None:
|
||||
DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
|
||||
return DANSWER_BOT_APP_ID
|
||||
|
||||
|
||||
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
|
||||
bot_tag_id = get_danswer_bot_app_id(web_client=client)
|
||||
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
|
||||
|
||||
|
||||
class ChannelIdAdapter(logging.LoggerAdapter):
|
||||
"""This is used to add the channel ID to all log messages
|
||||
emitted in this file"""
|
||||
@@ -95,8 +112,8 @@ def respond_in_thread(
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
|
||||
|
||||
def build_feedback_block_id(
|
||||
query_event_id: int,
|
||||
def build_feedback_id(
|
||||
message_id: int,
|
||||
document_id: str | None = None,
|
||||
document_rank: int | None = None,
|
||||
) -> str:
|
||||
@@ -108,21 +125,21 @@ def build_feedback_block_id(
|
||||
raise ValueError(
|
||||
"Separator pattern should not already exist in document id"
|
||||
)
|
||||
block_id = ID_SEPARATOR.join(
|
||||
[str(query_event_id), document_id, str(document_rank)]
|
||||
feedback_id = ID_SEPARATOR.join(
|
||||
[str(message_id), document_id, str(document_rank)]
|
||||
)
|
||||
else:
|
||||
block_id = str(query_event_id)
|
||||
feedback_id = str(message_id)
|
||||
|
||||
return unique_prefix + ID_SEPARATOR + block_id
|
||||
return unique_prefix + ID_SEPARATOR + feedback_id
|
||||
|
||||
|
||||
def decompose_block_id(block_id: str) -> tuple[int, str | None, int | None]:
|
||||
def decompose_feedback_id(feedback_id: str) -> tuple[int, str | None, int | None]:
|
||||
"""Decompose into query_id, document_id, document_rank, see above function"""
|
||||
try:
|
||||
components = block_id.split(ID_SEPARATOR)
|
||||
components = feedback_id.split(ID_SEPARATOR)
|
||||
if len(components) != 2 and len(components) != 4:
|
||||
raise ValueError("Block ID does not contain right number of elements")
|
||||
raise ValueError("Feedback ID does not contain right number of elements")
|
||||
|
||||
if len(components) == 2:
|
||||
return int(components[-1]), None, None
|
||||
@@ -131,7 +148,36 @@ def decompose_block_id(block_id: str) -> tuple[int, str | None, int | None]:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise ValueError("Received invalid Feedback Block Identifier")
|
||||
raise ValueError("Received invalid Feedback Identifier")
|
||||
|
||||
|
||||
def get_view_values(state_values: dict[str, Any]) -> dict[str, str]:
|
||||
"""Extract view values
|
||||
|
||||
Args:
|
||||
state_values (dict): The Slack view-submission values
|
||||
|
||||
Returns:
|
||||
dict: keys/values of the view state content
|
||||
"""
|
||||
view_values = {}
|
||||
for _, view_data in state_values.items():
|
||||
for k, v in view_data.items():
|
||||
if (
|
||||
"selected_option" in v
|
||||
and isinstance(v["selected_option"], dict)
|
||||
and "value" in v["selected_option"]
|
||||
):
|
||||
view_values[k] = v["selected_option"]["value"]
|
||||
elif "selected_options" in v and isinstance(v["selected_options"], list):
|
||||
view_values[k] = [
|
||||
x["value"] for x in v["selected_options"] if "value" in x
|
||||
]
|
||||
elif "selected_date" in v:
|
||||
view_values[k] = v["selected_date"]
|
||||
elif "value" in v:
|
||||
view_values[k] = v["value"]
|
||||
return view_values
|
||||
|
||||
|
||||
def translate_vespa_highlight_to_slack(match_strs: list[str], used_chars: int) -> str:
|
||||
@@ -201,3 +247,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list
|
||||
)
|
||||
|
||||
return user_ids
|
||||
|
||||
|
||||
def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None:
|
||||
response = client.users_info(user=user_id)
|
||||
if not response["ok"]:
|
||||
return None
|
||||
|
||||
user: dict = cast(dict[Any, dict], response.data).get("user", {})
|
||||
|
||||
return (
|
||||
user.get("real_name")
|
||||
or user.get("name")
|
||||
or user.get("profile", {}).get("email")
|
||||
)
|
||||
|
||||
|
||||
def read_slack_thread(
|
||||
channel: str, thread: str, client: WebClient
|
||||
) -> list[ThreadMessage]:
|
||||
thread_messages: list[ThreadMessage] = []
|
||||
response = client.conversations_replies(channel=channel, ts=thread)
|
||||
replies = cast(dict, response.data).get("messages", [])
|
||||
for reply in replies:
|
||||
if "user" in reply and "bot_id" not in reply:
|
||||
message = remove_danswer_bot_tag(reply["text"], client=client)
|
||||
user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
|
||||
message_type = MessageType.USER
|
||||
else:
|
||||
self_app_id = get_danswer_bot_app_id(client)
|
||||
|
||||
# Only include bot messages from Danswer, other bots are not taken in as context
|
||||
if self_app_id != reply.get("user"):
|
||||
continue
|
||||
|
||||
blocks = reply["blocks"]
|
||||
if len(blocks) <= 1:
|
||||
continue
|
||||
|
||||
# The useful block is the second one after the header block that says AI Answer
|
||||
message = reply["blocks"][1]["text"]["text"]
|
||||
|
||||
if message.startswith("_Filters"):
|
||||
if len(blocks) <= 2:
|
||||
continue
|
||||
message = reply["blocks"][2]["text"]["text"]
|
||||
|
||||
user_sem_id = "Assistant"
|
||||
message_type = MessageType.ASSISTANT
|
||||
|
||||
thread_messages.append(
|
||||
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
|
||||
)
|
||||
|
||||
return thread_messages
|
||||
|
||||
0
backend/danswer/db/__init__.py
Normal file
0
backend/danswer/db/__init__.py
Normal file
@@ -1,30 +1,64 @@
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import not_
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import DocumentSet as DBDocumentSet
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import ToolInfo
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.search.models import RecencyBiasSetting
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SearchDoc as ServerSearchDoc
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_chat_sessions_by_user(
|
||||
def get_chat_session_by_id(
|
||||
chat_session_id: int, user_id: UUID | None, db_session: Session
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(
|
||||
ChatSession.id == chat_session_id, ChatSession.user_id == user_id
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_session = result.scalar_one_or_none()
|
||||
|
||||
if not chat_session:
|
||||
raise ValueError("Invalid Chat Session ID provided")
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Chat session has been deleted")
|
||||
|
||||
return chat_session
|
||||
|
||||
|
||||
def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
include_one_shot: bool = False,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if not include_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
@@ -34,76 +68,18 @@ def fetch_chat_sessions_by_user(
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def fetch_chat_messages_by_session(
|
||||
chat_session_id: int, db_session: Session
|
||||
) -> list[ChatMessage]:
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.chat_session_id == chat_session_id)
|
||||
.order_by(ChatMessage.message_number.asc(), ChatMessage.edit_number.asc())
|
||||
)
|
||||
result = db_session.execute(stmt).scalars().all()
|
||||
return list(result)
|
||||
|
||||
|
||||
def fetch_chat_message(
|
||||
chat_session_id: int, message_number: int, edit_number: int, db_session: Session
|
||||
) -> ChatMessage:
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
(ChatMessage.chat_session_id == chat_session_id)
|
||||
& (ChatMessage.message_number == message_number)
|
||||
& (ChatMessage.edit_number == edit_number)
|
||||
)
|
||||
.options(selectinload(ChatMessage.chat_session))
|
||||
)
|
||||
|
||||
chat_message = db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not chat_message:
|
||||
raise ValueError("Invalid Chat Message specified")
|
||||
|
||||
return chat_message
|
||||
|
||||
|
||||
def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
result = db_session.execute(stmt)
|
||||
chat_session = result.scalar_one_or_none()
|
||||
|
||||
if not chat_session:
|
||||
raise ValueError("Invalid Chat Session ID provided")
|
||||
|
||||
return chat_session
|
||||
|
||||
|
||||
def verify_parent_exists(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
db_session: Session,
|
||||
) -> ChatMessage:
|
||||
stmt = select(ChatMessage).where(
|
||||
(ChatMessage.chat_session_id == chat_session_id)
|
||||
& (ChatMessage.message_number == message_number - 1)
|
||||
& (ChatMessage.edit_number == parent_edit_number)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
try:
|
||||
return result.scalar_one()
|
||||
except NoResultFound:
|
||||
raise ValueError("Invalid message, parent message not found")
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
description: str, user_id: UUID | None, db_session: Session
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None = None,
|
||||
one_shot: bool = False,
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
one_shot=one_shot,
|
||||
)
|
||||
|
||||
db_session.add(chat_session)
|
||||
@@ -115,14 +91,13 @@ def create_chat_session(
|
||||
def update_chat_session(
|
||||
user_id: UUID | None, chat_session_id: int, description: str, db_session: Session
|
||||
) -> ChatSession:
|
||||
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Trying to rename a deleted chat session")
|
||||
|
||||
if user_id != chat_session.user_id:
|
||||
raise ValueError("User trying to update chat of another user.")
|
||||
|
||||
chat_session.description = description
|
||||
|
||||
db_session.commit()
|
||||
@@ -136,10 +111,9 @@ def delete_chat_session(
|
||||
db_session: Session,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
|
||||
|
||||
if user_id != chat_session.user_id:
|
||||
raise ValueError("User trying to delete chat of another user.")
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
stmt_messages = delete(ChatMessage).where(
|
||||
@@ -156,185 +130,374 @@ def delete_chat_session(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _set_latest_chat_message_no_commit(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
edit_number: int,
|
||||
def get_chat_message(
|
||||
chat_message_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
if message_number != 0 and parent_edit_number is None:
|
||||
raise ValueError(
|
||||
"Only initial message in a chat is allowed to not have a parent"
|
||||
) -> ChatMessage:
|
||||
stmt = select(ChatMessage).where(ChatMessage.id == chat_message_id)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_message = result.scalar_one_or_none()
|
||||
|
||||
if not chat_message:
|
||||
raise ValueError("Invalid Chat Message specified")
|
||||
|
||||
chat_user = chat_message.chat_session.user
|
||||
expected_user_id = chat_user.id if chat_user is not None else None
|
||||
|
||||
if expected_user_id != user_id:
|
||||
logger.error(
|
||||
f"User {user_id} tried to fetch a chat message that does not belong to them"
|
||||
)
|
||||
raise ValueError("Chat message does not belong to user")
|
||||
|
||||
return chat_message
|
||||
|
||||
|
||||
def get_chat_messages_by_session(
|
||||
chat_session_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
skip_permission_check: bool = False,
|
||||
) -> list[ChatMessage]:
|
||||
if not skip_permission_check:
|
||||
get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
db_session.query(ChatMessage).filter(
|
||||
and_(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
ChatMessage.message_number == message_number,
|
||||
ChatMessage.parent_edit_number == parent_edit_number,
|
||||
)
|
||||
).update({ChatMessage.latest: False})
|
||||
stmt = (
|
||||
select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
|
||||
# Start with the root message which has no parent
|
||||
.order_by(nullsfirst(ChatMessage.parent_message))
|
||||
)
|
||||
|
||||
db_session.query(ChatMessage).filter(
|
||||
and_(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
ChatMessage.message_number == message_number,
|
||||
ChatMessage.edit_number == edit_number,
|
||||
result = db_session.execute(stmt).scalars().all()
|
||||
|
||||
return list(result)
|
||||
|
||||
|
||||
def get_or_create_root_message(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
) -> ChatMessage:
|
||||
try:
|
||||
root_message: ChatMessage | None = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
ChatMessage.parent_message.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
).update({ChatMessage.latest: True})
|
||||
except MultipleResultsFound:
|
||||
raise Exception(
|
||||
"Multiple root messages found for chat session. Data inconsistency detected."
|
||||
)
|
||||
|
||||
if root_message is not None:
|
||||
return root_message
|
||||
else:
|
||||
new_root_message = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
prompt_id=None,
|
||||
parent_message=None,
|
||||
latest_child_message=None,
|
||||
message="",
|
||||
token_count=0,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
db_session.add(new_root_message)
|
||||
db_session.commit()
|
||||
return new_root_message
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_message: ChatMessage,
|
||||
message: str,
|
||||
prompt_id: int | None,
|
||||
token_count: int,
|
||||
parent_edit_number: int | None,
|
||||
message_type: MessageType,
|
||||
db_session: Session,
|
||||
retrieval_docs: dict[str, Any] | None = None,
|
||||
rephrased_query: str | None = None,
|
||||
error: str | None = None,
|
||||
reference_docs: list[DBSearchDoc] | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
commit: bool = True,
|
||||
) -> ChatMessage:
|
||||
"""Creates a new chat message and sets it to the latest message of its parent message"""
|
||||
# Get the count of existing edits at the provided message number
|
||||
latest_edit_number = (
|
||||
db_session.query(func.max(ChatMessage.edit_number))
|
||||
.filter_by(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# The new message is a new edit at the provided message number
|
||||
new_edit_number = latest_edit_number + 1 if latest_edit_number is not None else 0
|
||||
|
||||
# Create a new message and set it to be the latest for its parent message
|
||||
new_chat_message = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=new_edit_number,
|
||||
parent_message=parent_message.id,
|
||||
latest_child_message=None,
|
||||
message=message,
|
||||
reference_docs=retrieval_docs,
|
||||
rephrased_query=rephrased_query,
|
||||
prompt_id=prompt_id,
|
||||
token_count=token_count,
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
error=error,
|
||||
)
|
||||
|
||||
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
|
||||
if reference_docs:
|
||||
new_chat_message.search_docs = reference_docs
|
||||
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
# Set the previous latest message of the same parent, as no longer the latest
|
||||
_set_latest_chat_message_no_commit(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=new_edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
# Flush the session to get an ID for the new chat message
|
||||
db_session.flush()
|
||||
|
||||
db_session.commit()
|
||||
parent_message.latest_child_message = new_chat_message.id
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return new_chat_message
|
||||
|
||||
|
||||
def set_latest_chat_message(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
edit_number: int,
|
||||
def set_as_latest_chat_message(
|
||||
chat_message: ChatMessage,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
_set_latest_chat_message_no_commit(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=edit_number,
|
||||
db_session=db_session,
|
||||
parent_message_id = chat_message.parent_message
|
||||
|
||||
if parent_message_id is None:
|
||||
raise RuntimeError(
|
||||
f"Trying to set a latest message without parent, message id: {chat_message.id}"
|
||||
)
|
||||
|
||||
parent_message = get_chat_message(
|
||||
chat_message_id=parent_message_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
parent_message.latest_child_message = chat_message.id
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
||||
def get_prompt_by_id(
|
||||
prompt_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
) -> Prompt:
|
||||
stmt = select(Prompt).where(
|
||||
Prompt.id == prompt_id, or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
|
||||
)
|
||||
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise ValueError(
|
||||
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_persona_by_id(
|
||||
persona_id: int,
|
||||
# if user_id is `None` assume the user is an admin or auth is disabled
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
|
||||
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Persona.deleted.is_(False))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
persona = result.scalar_one_or_none()
|
||||
|
||||
if persona is None:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist")
|
||||
raise ValueError(
|
||||
f"Persona with ID {persona_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def fetch_default_persona_by_name(
|
||||
persona_name: str, db_session: Session
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(
|
||||
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
|
||||
)
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def get_personas_by_ids(
|
||||
persona_ids: list[int], db_session: Session
|
||||
) -> Sequence[Persona]:
|
||||
"""Unsafe, can fetch personas from all users"""
|
||||
if not persona_ids:
|
||||
return []
|
||||
personas = db_session.scalars(
|
||||
select(Persona).where(Persona.id.in_(persona_ids))
|
||||
).all()
|
||||
|
||||
return personas
|
||||
|
||||
|
||||
def get_prompt_by_name(
|
||||
prompt_name: str, user_id: UUID | None, shared: bool, db_session: Session
|
||||
) -> Prompt | None:
|
||||
"""Cannot do shared and user owned simultaneously as there may be two of those"""
|
||||
stmt = select(Prompt).where(Prompt.name == prompt_name)
|
||||
if shared:
|
||||
stmt = stmt.where(Prompt.user_id.is_(None))
|
||||
else:
|
||||
stmt = stmt.where(Prompt.user_id == user_id)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | None:
|
||||
"""Try to fetch a default persona by name first,
|
||||
if not exist, try to find any persona with the name
|
||||
Note that name is not guaranteed unique unless default is true"""
|
||||
persona = fetch_default_persona_by_name(persona_name, db_session)
|
||||
if persona is not None:
|
||||
return persona
|
||||
def get_persona_by_name(
|
||||
persona_name: str, user_id: UUID | None, shared: bool, db_session: Session
|
||||
) -> Persona | None:
|
||||
"""Cannot do shared and user owned simultaneously as there may be two of those"""
|
||||
stmt = select(Persona).where(Persona.name == persona_name)
|
||||
if shared:
|
||||
stmt = stmt.where(Persona.user_id.is_(None))
|
||||
else:
|
||||
stmt = stmt.where(Persona.user_id == user_id)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
stmt = select(Persona).where(Persona.name == persona_name) # noqa: E712
|
||||
result = db_session.execute(stmt).first()
|
||||
if result:
|
||||
return result[0]
|
||||
return None
|
||||
|
||||
def upsert_prompt(
|
||||
user_id: UUID | None,
|
||||
name: str,
|
||||
description: str,
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
include_citations: bool,
|
||||
datetime_aware: bool,
|
||||
personas: list[Persona] | None,
|
||||
shared: bool,
|
||||
db_session: Session,
|
||||
prompt_id: int | None = None,
|
||||
default_prompt: bool = True,
|
||||
commit: bool = True,
|
||||
) -> Prompt:
|
||||
if prompt_id is not None:
|
||||
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
|
||||
else:
|
||||
prompt = get_prompt_by_name(
|
||||
prompt_name=name, user_id=user_id, shared=shared, db_session=db_session
|
||||
)
|
||||
|
||||
if prompt:
|
||||
if not default_prompt and prompt.default_prompt:
|
||||
raise ValueError("Cannot update default prompt with non-default.")
|
||||
|
||||
prompt.name = name
|
||||
prompt.description = description
|
||||
prompt.system_prompt = system_prompt
|
||||
prompt.task_prompt = task_prompt
|
||||
prompt.include_citations = include_citations
|
||||
prompt.datetime_aware = datetime_aware
|
||||
prompt.default_prompt = default_prompt
|
||||
|
||||
if personas is not None:
|
||||
prompt.personas.clear()
|
||||
prompt.personas = personas
|
||||
|
||||
else:
|
||||
prompt = Prompt(
|
||||
id=prompt_id,
|
||||
user_id=None if shared else user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
datetime_aware=datetime_aware,
|
||||
default_prompt=default_prompt,
|
||||
personas=personas or [],
|
||||
)
|
||||
db_session.add(prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
# Flush the session so that the Prompt has an ID
|
||||
db_session.flush()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user_id: UUID | None,
|
||||
name: str,
|
||||
retrieval_enabled: bool,
|
||||
datetime_aware: bool,
|
||||
system_text: str | None,
|
||||
tools: list[ToolInfo] | None,
|
||||
hint_text: str | None,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
prompts: list[Prompt] | None,
|
||||
document_sets: list[DBDocumentSet] | None,
|
||||
llm_model_version_override: str | None,
|
||||
shared: bool,
|
||||
db_session: Session,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
document_sets: list[DocumentSetDBModel] | None = None,
|
||||
commit: bool = True,
|
||||
) -> Persona:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
|
||||
# Default personas are defined via yaml files at deployment time
|
||||
if persona is None and default_persona:
|
||||
persona = fetch_default_persona_by_name(name, db_session)
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
persona = get_persona_by_name(
|
||||
persona_name=name, user_id=user_id, shared=shared, db_session=db_session
|
||||
)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
|
||||
persona.name = name
|
||||
persona.retrieval_enabled = retrieval_enabled
|
||||
persona.datetime_aware = datetime_aware
|
||||
persona.system_text = system_text
|
||||
persona.tools = tools
|
||||
persona.hint_text = hint_text
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
persona.llm_relevance_filter = llm_relevance_filter
|
||||
persona.llm_filter_extraction = llm_filter_extraction
|
||||
persona.recency_bias = recency_bias
|
||||
persona.default_persona = default_persona
|
||||
persona.llm_model_version_override = llm_model_version_override
|
||||
persona.deleted = False # Un-delete if previously deleted
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
if document_sets is not None:
|
||||
persona.document_sets.clear()
|
||||
persona.document_sets = document_sets
|
||||
persona.document_sets = document_sets or []
|
||||
|
||||
if prompts is not None:
|
||||
persona.prompts.clear()
|
||||
persona.prompts = prompts
|
||||
|
||||
else:
|
||||
persona = Persona(
|
||||
id=persona_id,
|
||||
user_id=None if shared else user_id,
|
||||
name=name,
|
||||
retrieval_enabled=retrieval_enabled,
|
||||
datetime_aware=datetime_aware,
|
||||
system_text=system_text,
|
||||
tools=tools,
|
||||
hint_text=hint_text,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
default_persona=default_persona,
|
||||
document_sets=document_sets if document_sets else [],
|
||||
prompts=prompts or [],
|
||||
document_sets=document_sets or [],
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
@@ -345,3 +508,204 @@ def upsert_persona(
|
||||
db_session.flush()
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def mark_prompt_as_deleted(
|
||||
prompt_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=prompt_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
prompt.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_persona_as_deleted(
|
||||
persona_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
persona.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_visibility(
|
||||
persona_id: int,
|
||||
is_visible: bool,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user_id=None, db_session=db_session
|
||||
)
|
||||
persona.is_visible = is_visible
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_all_personas_display_priority(
|
||||
display_priority_map: dict[int, int],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Updates the display priority of all lives Personas"""
|
||||
personas = get_personas(user_id=None, db_session=db_session)
|
||||
available_persona_ids = {persona.id for persona in personas}
|
||||
if available_persona_ids != set(display_priority_map.keys()):
|
||||
raise ValueError("Invalid persona IDs provided")
|
||||
|
||||
for persona in personas:
|
||||
persona.display_priority = display_priority_map[persona.id]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_default: bool = True,
|
||||
include_deleted: bool = False,
|
||||
) -> Sequence[Prompt]:
|
||||
stmt = select(Prompt).where(
|
||||
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
|
||||
)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Prompt.default_prompt.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_personas(
|
||||
# if user_id is `None` assume the user is an admin or auth is disabled
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_default: bool = True,
|
||||
include_slack_bot_personas: bool = False,
|
||||
include_deleted: bool = False,
|
||||
) -> Sequence[Persona]:
|
||||
stmt = select(Persona)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.default_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Persona.deleted.is_(False))
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_doc_query_identifiers_from_model(
    search_doc_ids: list[int],
    chat_session: ChatSession,
    user_id: UUID | None,
    db_session: Session,
) -> list[tuple[str, int]]:
    """Given a list of search_doc_ids, return their (document_id, chunk_ind)
    identifiers after validating ownership.

    Args:
        search_doc_ids: primary keys of SearchDoc rows to resolve.
        chat_session: the chat session the docs must belong to.
        user_id: the requesting user; must own the chat session.
        db_session: SQLAlchemy session to query with.

    Raises:
        ValueError: if the chat session does not belong to ``user_id``, or if
            any referenced doc originates from a different chat session.
    """
    search_docs = (
        db_session.query(SearchDoc).filter(SearchDoc.id.in_(search_doc_ids)).all()
    )

    if user_id != chat_session.user_id:
        logger.error(
            f"Docs referenced are from a chat session not belonging to user {user_id}"
        )
        raise ValueError("Docs references do not belong to user")

    # generator (instead of a materialized list) lets any() short-circuit on
    # the first doc from a foreign session
    if any(
        doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs
    ):
        raise ValueError("Invalid reference doc, not from this chat session.")

    doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs]

    return doc_query_identifiers
|
||||
|
||||
|
||||
def create_db_search_doc(
    server_search_doc: ServerSearchDoc,
    db_session: Session,
) -> SearchDoc:
    """Persist a server-side search doc as a SearchDoc row and return it.

    Commits immediately so the returned row has its generated id populated.
    """
    field_values = dict(
        document_id=server_search_doc.document_id,
        chunk_ind=server_search_doc.chunk_ind,
        semantic_id=server_search_doc.semantic_identifier,
        link=server_search_doc.link,
        blurb=server_search_doc.blurb,
        source_type=server_search_doc.source_type,
        boost=server_search_doc.boost,
        hidden=server_search_doc.hidden,
        doc_metadata=server_search_doc.metadata,
        score=server_search_doc.score,
        match_highlights=server_search_doc.match_highlights,
        updated_at=server_search_doc.updated_at,
        primary_owners=server_search_doc.primary_owners,
        secondary_owners=server_search_doc.secondary_owners,
    )
    new_search_doc = SearchDoc(**field_values)

    db_session.add(new_search_doc)
    db_session.commit()

    return new_search_doc
|
||||
|
||||
|
||||
def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | None:
    """There are no safety checks here like user permission etc., use with caution"""
    return db_session.query(SearchDoc).filter(SearchDoc.id == doc_id).first()
|
||||
|
||||
|
||||
def translate_db_search_doc_to_server_search_doc(
    db_search_doc: SearchDoc,
) -> SavedSearchDoc:
    """Map a persisted SearchDoc row onto the API-facing SavedSearchDoc model."""
    doc = db_search_doc
    return SavedSearchDoc(
        db_doc_id=doc.id,
        document_id=doc.document_id,
        chunk_ind=doc.chunk_ind,
        semantic_identifier=doc.semantic_id,
        link=doc.link,
        blurb=doc.blurb,
        source_type=doc.source_type,
        boost=doc.boost,
        hidden=doc.hidden,
        metadata=doc.doc_metadata,
        score=doc.score,
        match_highlights=doc.match_highlights,
        updated_at=doc.updated_at,
        primary_owners=doc.primary_owners,
        secondary_owners=doc.secondary_owners,
    )
|
||||
|
||||
|
||||
def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> RetrievalDocs:
    """Wrap a chat message's attached SearchDocs as server-model RetrievalDocs."""
    translated = [
        translate_db_search_doc_to_server_search_doc(doc)
        for doc in chat_message.search_docs
    ]
    return RetrievalDocs(top_documents=translated)
|
||||
|
||||
|
||||
def translate_db_message_to_chat_message_detail(
    chat_message: ChatMessage,
) -> ChatMessageDetail:
    """Convert a DB ChatMessage row into the API-facing ChatMessageDetail."""
    return ChatMessageDetail(
        message_id=chat_message.id,
        parent_message=chat_message.parent_message,
        latest_child_message=chat_message.latest_child_message,
        message=chat_message.message,
        rephrased_query=chat_message.rephrased_query,
        context_docs=get_retrieval_docs_from_chat_message(chat_message),
        message_type=chat_message.message_type,
        time_sent=chat_message.time_sent,
        citations=chat_message.citations,
    )
|
||||
|
||||
@@ -11,8 +11,8 @@ from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.server.models import ConnectorBase
|
||||
from danswer.server.models import ObjectCreationIdResponse
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -36,8 +36,12 @@ def fetch_connectors(
|
||||
return list(results.all())
|
||||
|
||||
|
||||
def connector_by_name_exists(connector_name: str, db_session: Session) -> bool:
|
||||
stmt = select(Connector).where(Connector.name == connector_name)
|
||||
def connector_by_name_source_exists(
|
||||
connector_name: str, source: DocumentSource, db_session: Session
|
||||
) -> bool:
|
||||
stmt = select(Connector).where(
|
||||
Connector.name == connector_name, Connector.source == source
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
connector = result.scalar_one_or_none()
|
||||
return connector is not None
|
||||
@@ -50,11 +54,26 @@ def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector |
|
||||
return connector
|
||||
|
||||
|
||||
def fetch_ingestion_connector_by_name(
    connector_name: str, db_session: Session
) -> Connector | None:
    """Look up the ingestion-API connector with the given name, if any."""
    stmt = select(Connector).where(
        Connector.name == connector_name,
        Connector.source == DocumentSource.INGESTION_API,
    )
    return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def create_connector(
|
||||
connector_data: ConnectorBase,
|
||||
db_session: Session,
|
||||
) -> ObjectCreationIdResponse:
|
||||
if connector_by_name_exists(connector_data.name, db_session):
|
||||
if connector_by_name_source_exists(
|
||||
connector_data.name, connector_data.source, db_session
|
||||
):
|
||||
raise ValueError(
|
||||
"Connector by this name already exists, duplicate naming not allowed."
|
||||
)
|
||||
@@ -82,8 +101,8 @@ def update_connector(
|
||||
if connector is None:
|
||||
return None
|
||||
|
||||
if connector_data.name != connector.name and connector_by_name_exists(
|
||||
connector_data.name, db_session
|
||||
if connector_data.name != connector.name and connector_by_name_source_exists(
|
||||
connector_data.name, connector_data.source, db_session
|
||||
):
|
||||
raise ValueError(
|
||||
"Connector by this name already exists, duplicate naming not allowed."
|
||||
@@ -202,3 +221,44 @@ def fetch_latest_index_attempts_by_status(
|
||||
),
|
||||
)
|
||||
return cast(list[IndexAttempt], query.all())
|
||||
|
||||
|
||||
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
    """Return the distinct connector sources, excluding the internal ingestion API."""
    rows = db_session.query(Connector.source).distinct().all()

    # rows are single-column tuples; unwrap and drop the ingestion-API source
    return [row[0] for row in rows if row[0] != DocumentSource.INGESTION_API]
|
||||
|
||||
|
||||
def create_initial_default_connector(db_session: Session) -> None:
    """Ensure the default (id=0) ingestion-API connector exists.

    If a connector with id 0 already exists, validate that it has the
    expected shape; otherwise create it.

    Raises:
        ValueError: if an existing id-0 connector does not match the
            expected initial-state values.
    """
    default_connector_id = 0
    existing = fetch_connector_by_id(default_connector_id, db_session)

    if existing is not None:
        # equivalent (De Morgan) form of the original or-chain of bad states
        is_valid = (
            existing.source == DocumentSource.INGESTION_API
            and existing.input_type == InputType.LOAD_STATE
            and existing.refresh_freq is None
            and not existing.disabled
        )
        if not is_valid:
            raise ValueError(
                "DB is not in a valid initial state. "
                "Default connector does not have expected values."
            )
        return

    db_session.add(
        Connector(
            id=default_connector_id,
            name="Ingestion API",
            source=DocumentSource.INGESTION_API,
            input_type=InputType.LOAD_STATE,
            connector_specific_config={},
            refresh_freq=None,
        )
    )
    db_session.commit()
|
||||
|
||||
@@ -54,6 +54,8 @@ def get_last_successful_attempt_time(
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
) -> float:
|
||||
"""Gets the timestamp of the last successful index run stored in
|
||||
the CC Pair row in the database"""
|
||||
connector_credential_pair = get_connector_credential_pair(
|
||||
connector_id, credential_id, db_session
|
||||
)
|
||||
@@ -84,7 +86,10 @@ def update_connector_credential_pair(
|
||||
cc_pair.last_attempt_status = attempt_status
|
||||
# simply don't update last_successful_index_time if run_dt is not specified
|
||||
# at worst, this would result in re-indexing documents that were already indexed
|
||||
if attempt_status == IndexingStatus.SUCCESS and run_dt is not None:
|
||||
if (
|
||||
attempt_status == IndexingStatus.SUCCESS
|
||||
or attempt_status == IndexingStatus.IN_PROGRESS
|
||||
) and run_dt is not None:
|
||||
cc_pair.last_successful_index_time = run_dt
|
||||
if net_docs is not None:
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
@@ -117,6 +122,27 @@ def mark_all_in_progress_cc_pairs_failed(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def associate_default_cc_pair(db_session: Session) -> None:
    """Create the default connector/credential pairing (0, 0) if it is absent."""
    already_linked = (
        db_session.query(ConnectorCredentialPair)
        .filter(
            ConnectorCredentialPair.connector_id == 0,
            ConnectorCredentialPair.credential_id == 0,
        )
        .one_or_none()
    )
    if already_linked is not None:
        return

    db_session.add(
        ConnectorCredentialPair(
            connector_id=0,
            credential_id=0,
            name="DefaultCCPair",
        )
    )
    db_session.commit()
|
||||
|
||||
|
||||
def add_credential_to_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
|
||||
1
backend/danswer/db/constants.py
Normal file
1
backend/danswer/db/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
|
||||
@@ -9,18 +9,20 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import User
|
||||
from danswer.server.models import CredentialBase
|
||||
from danswer.server.models import ObjectCreationIdResponse
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select:
|
||||
def _attach_user_filters(
|
||||
stmt: Select[tuple[Credential]],
|
||||
user: User | None,
|
||||
assume_admin: bool = False, # Used with API key
|
||||
) -> Select:
|
||||
"""Attaches filters to the statement to ensure that the user can only
|
||||
access the appropriate credentials"""
|
||||
if user:
|
||||
@@ -29,11 +31,18 @@ def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) ->
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
Credential.user_id.is_(None),
|
||||
Credential.is_admin == True, # noqa: E712
|
||||
Credential.admin_public == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(Credential.user_id == user.id)
|
||||
elif assume_admin:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Credential.user_id.is_(None),
|
||||
Credential.admin_public == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
|
||||
return stmt
|
||||
|
||||
@@ -49,10 +58,13 @@ def fetch_credentials(
|
||||
|
||||
|
||||
def fetch_credential_by_id(
|
||||
credential_id: int, user: User | None, db_session: Session
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
assume_admin: bool = False,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).where(Credential.id == credential_id)
|
||||
stmt = _attach_user_filters(stmt, user)
|
||||
stmt = _attach_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
@@ -62,16 +74,16 @@ def create_credential(
|
||||
credential_data: CredentialBase,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> ObjectCreationIdResponse:
|
||||
) -> Credential:
|
||||
credential = Credential(
|
||||
credential_json=credential_data.credential_json,
|
||||
user_id=user.id if user else None,
|
||||
is_admin=credential_data.is_admin,
|
||||
admin_public=credential_data.admin_public,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
|
||||
return ObjectCreationIdResponse(id=credential.id)
|
||||
return credential
|
||||
|
||||
|
||||
def update_credential(
|
||||
@@ -131,30 +143,26 @@ def delete_credential(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_initial_public_credential() -> None:
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
public_cred_id = 0
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
"There must exist an empty public credential for data connectors that do not require additional Auth."
|
||||
)
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
|
||||
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
|
||||
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
|
||||
|
||||
if first_credential is not None:
|
||||
if (
|
||||
first_credential.credential_json != {}
|
||||
or first_credential.user is not None
|
||||
):
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
if first_credential is not None:
|
||||
if first_credential.credential_json != {} or first_credential.user is not None:
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
|
||||
credential = Credential(
|
||||
id=public_cred_id,
|
||||
credential_json={},
|
||||
user_id=None,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
credential = Credential(
|
||||
id=public_cred_id,
|
||||
credential_json={},
|
||||
user_id=None,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_google_drive_service_account_credentials(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -16,9 +17,10 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.tag import delete_document_tags_for_documents
|
||||
from danswer.db.utils import model_to_dict
|
||||
from danswer.document_index.interfaces import DocumentMetadata
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -39,6 +41,15 @@ def get_documents_for_connector_credential_pair(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_documents_by_ids(
    document_ids: list[str],
    db_session: Session,
) -> list[DbDocument]:
    """Fetch the Document rows whose ids appear in ``document_ids``."""
    query = select(DbDocument).where(DbDocument.id.in_(document_ids))
    return list(db_session.execute(query).scalars().all())
|
||||
|
||||
|
||||
def get_document_connector_cnts(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
@@ -136,9 +147,13 @@ def get_acccess_info_for_documents(
|
||||
|
||||
|
||||
def upsert_documents(
|
||||
db_session: Session, document_metadata_batch: list[DocumentMetadata]
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
initial_boost: int = DEFAULT_BOOST,
|
||||
) -> None:
|
||||
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
|
||||
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.
|
||||
Also note, this function should not be used for updating documents, only creating and
|
||||
ensuring that it exists. It IGNORES the doc_updated_at field"""
|
||||
seen_documents: dict[str, DocumentMetadata] = {}
|
||||
for document_metadata in document_metadata_batch:
|
||||
doc_id = document_metadata.document_id
|
||||
@@ -154,11 +169,12 @@ def upsert_documents(
|
||||
model_to_dict(
|
||||
DbDocument(
|
||||
id=doc.document_id,
|
||||
boost=DEFAULT_BOOST,
|
||||
from_ingestion_api=doc.from_ingestion_api,
|
||||
boost=initial_boost,
|
||||
hidden=False,
|
||||
semantic_id=doc.semantic_identifier,
|
||||
link=doc.first_link,
|
||||
doc_updated_at=doc.doc_updated_at,
|
||||
doc_updated_at=None, # this is intentional
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
)
|
||||
@@ -200,6 +216,21 @@ def upsert_document_by_connector_credential_pair(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_docs_updated_at(
    ids_to_new_updated_at: dict[str, datetime],
    db_session: Session,
) -> None:
    """Set doc_updated_at for each document listed in the mapping and commit."""
    target_ids = list(ids_to_new_updated_at)
    matching_docs = (
        db_session.query(DbDocument).filter(DbDocument.id.in_(target_ids)).all()
    )

    for doc in matching_docs:
        doc.doc_updated_at = ids_to_new_updated_at[doc.id]

    db_session.commit()
||||
|
||||
|
||||
def upsert_documents_complete(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
@@ -242,6 +273,7 @@ def delete_documents_complete(db_session: Session, document_ids: list[str]) -> N
|
||||
delete_document_feedback_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
delete_document_tags_for_documents(document_ids=document_ids, db_session=db_session)
|
||||
delete_documents(db_session, document_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from danswer.server.models import DocumentSetCreationRequest
|
||||
from danswer.server.models import DocumentSetUpdateRequest
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
|
||||
|
||||
|
||||
def _delete_document_set_cc_pairs__no_commit(
|
||||
@@ -60,6 +60,8 @@ def get_document_set_by_name(
|
||||
def get_document_sets_by_ids(
    db_session: Session, document_set_ids: list[int]
) -> Sequence[DocumentSetDBModel]:
    """Fetch DocumentSet rows by id; an empty id list short-circuits to []."""
    if not document_set_ids:
        return []

    query = select(DocumentSetDBModel).where(
        DocumentSetDBModel.id.in_(document_set_ids)
    )
    return db_session.scalars(query).all()
||||
@@ -396,3 +398,33 @@ def get_or_create_document_set_by_name(
|
||||
db_session.commit()
|
||||
|
||||
return new_doc_set
|
||||
|
||||
|
||||
def check_document_sets_are_public(
    db_session: Session,
    document_set_ids: list[int],
) -> bool:
    """Return True iff every cc-pair attached to the given document sets is public.

    Works in two stages: first collect the connector-credential-pair ids
    linked to any of the document sets, then probe for a single non-public
    pair among them.
    """
    # subquery of all cc-pair ids referenced by the given document sets
    connector_credential_pair_ids = (
        db_session.query(
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
        )
        .filter(
            DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
        )
        .subquery()
    )

    # existence check: limit(1) avoids fetching more than one row
    not_public_exists = (
        db_session.query(ConnectorCredentialPair.id)
        .filter(
            ConnectorCredentialPair.id.in_(
                connector_credential_pair_ids  # type:ignore
            ),
            ConnectorCredentialPair.is_public.is_(False),
        )
        .limit(1)
        .first()
        is not None
    )

    # public iff no non-public pair was found
    return not not_public_exists
|
||||
|
||||
@@ -4,34 +4,19 @@ from sqlalchemy import asc
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.db.models import ChatMessage as DbChatMessage
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.models import ChatMessageFeedback
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import DocumentRetrievalFeedback
|
||||
from danswer.db.models import QueryEvent
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
|
||||
def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
    """Return the QueryEvent with the given id.

    Raises:
        ValueError: if no such event exists.
    """
    query_event = db_session.execute(
        select(QueryEvent).where(QueryEvent.id == query_id)
    ).scalar_one_or_none()

    if not query_event:
        raise ValueError("Invalid Query Event ID Provided")

    return query_event
||||
|
||||
|
||||
def fetch_docs_by_id(doc_id: str, db_session: Session) -> DbDocument:
|
||||
def fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument:
|
||||
stmt = select(DbDocument).where(DbDocument.id == doc_id)
|
||||
result = db_session.execute(stmt)
|
||||
doc = result.scalar_one_or_none()
|
||||
@@ -97,80 +82,20 @@ def update_document_hidden(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_query_event(
|
||||
db_session: Session,
|
||||
query: str,
|
||||
search_type: SearchType | None,
|
||||
llm_answer: str | None,
|
||||
user_id: UUID | None,
|
||||
retrieved_document_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
query_event = QueryEvent(
|
||||
query=query,
|
||||
selected_search_flow=search_type,
|
||||
llm_answer=llm_answer,
|
||||
retrieved_document_ids=retrieved_document_ids,
|
||||
user_id=user_id,
|
||||
)
|
||||
db_session.add(query_event)
|
||||
db_session.commit()
|
||||
|
||||
return query_event.id
|
||||
|
||||
|
||||
def update_query_event_feedback(
    db_session: Session,
    feedback: QAFeedbackType,
    query_id: int,
    user_id: UUID | None,
) -> None:
    """Record user feedback on a query event, enforcing ownership.

    Raises:
        ValueError: (via fetch) if the id is invalid, or if the event was
            run by a different user.
    """
    event = fetch_query_event_by_id(query_id, db_session)

    if user_id != event.user_id:
        raise ValueError("User trying to give feedback on a query run by another user.")

    event.feedback = feedback
    db_session.commit()
||||
|
||||
|
||||
def update_query_event_retrieved_documents(
    db_session: Session,
    retrieved_document_ids: list[str],
    query_id: int,
    user_id: UUID | None,
) -> None:
    """Attach the retrieved document ids to a query event, enforcing ownership.

    Raises:
        ValueError: (via fetch) if the id is invalid, or if the event was
            run by a different user.
    """
    event = fetch_query_event_by_id(query_id, db_session)

    if user_id != event.user_id:
        raise ValueError("User trying to update docs on a query run by another user.")

    event.retrieved_document_ids = retrieved_document_ids
    db_session.commit()
|
||||
|
||||
|
||||
def create_doc_retrieval_feedback(
|
||||
qa_event_id: int,
|
||||
message_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
clicked: bool = False,
|
||||
feedback: SearchFeedbackType | None = None,
|
||||
) -> None:
|
||||
"""Creates a new Document feedback row and updates the boost value in Postgres and Vespa"""
|
||||
if not clicked and feedback is None:
|
||||
raise ValueError("No action taken, not valid feedback")
|
||||
|
||||
query_event = fetch_query_event_by_id(qa_event_id, db_session)
|
||||
|
||||
if user_id != query_event.user_id:
|
||||
raise ValueError("User trying to give feedback on a query run by another user.")
|
||||
|
||||
doc_m = fetch_docs_by_id(document_id, db_session)
|
||||
db_doc = fetch_db_doc_by_id(document_id, db_session)
|
||||
|
||||
retrieval_feedback = DocumentRetrievalFeedback(
|
||||
qa_event_id=qa_event_id,
|
||||
chat_message_id=message_id,
|
||||
document_id=document_id,
|
||||
document_rank=document_rank,
|
||||
clicked=clicked,
|
||||
@@ -179,20 +104,23 @@ def create_doc_retrieval_feedback(
|
||||
|
||||
if feedback is not None:
|
||||
if feedback == SearchFeedbackType.ENDORSE:
|
||||
doc_m.boost += 1
|
||||
db_doc.boost += 1
|
||||
elif feedback == SearchFeedbackType.REJECT:
|
||||
doc_m.boost -= 1
|
||||
db_doc.boost -= 1
|
||||
elif feedback == SearchFeedbackType.HIDE:
|
||||
doc_m.hidden = True
|
||||
db_doc.hidden = True
|
||||
elif feedback == SearchFeedbackType.UNHIDE:
|
||||
doc_m.hidden = False
|
||||
db_doc.hidden = False
|
||||
else:
|
||||
raise ValueError("Unhandled document feedback type")
|
||||
|
||||
if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]:
|
||||
if feedback in [
|
||||
SearchFeedbackType.ENDORSE,
|
||||
SearchFeedbackType.REJECT,
|
||||
SearchFeedbackType.HIDE,
|
||||
]:
|
||||
update = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
boost=doc_m.boost,
|
||||
document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
|
||||
)
|
||||
# Updates are generally batched for efficiency, this case only 1 doc/value is updated
|
||||
document_index.update([update])
|
||||
@@ -213,40 +141,24 @@ def delete_document_feedback_for_documents(
|
||||
|
||||
|
||||
def create_chat_message_feedback(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
edit_number: int,
|
||||
is_positive: bool | None,
|
||||
feedback_text: str | None,
|
||||
chat_message_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
is_positive: bool | None = None,
|
||||
feedback_text: str | None = None,
|
||||
) -> None:
|
||||
if is_positive is None and feedback_text is None:
|
||||
raise ValueError("No feedback provided")
|
||||
|
||||
try:
|
||||
chat_message = (
|
||||
db_session.query(DbChatMessage)
|
||||
.filter_by(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
edit_number=edit_number,
|
||||
)
|
||||
.one()
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError("ChatMessage not found")
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=chat_message_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_message.message_type != MessageType.ASSISTANT:
|
||||
raise ValueError("Can only provide feedback on LLM Outputs")
|
||||
|
||||
if user_id is not None and chat_message.chat_session.user_id != user_id:
|
||||
raise ValueError("User trying to give feedback on a message by another user.")
|
||||
|
||||
message_feedback = ChatMessageFeedback(
|
||||
chat_message_chat_session_id=chat_session_id,
|
||||
chat_message_message_number=message_number,
|
||||
chat_message_edit_number=edit_number,
|
||||
chat_message_id=chat_message_id,
|
||||
is_positive=is_positive,
|
||||
feedback_text=feedback_text,
|
||||
)
|
||||
|
||||
@@ -7,13 +7,15 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -55,8 +57,13 @@ def get_inprogress_index_attempts(
|
||||
|
||||
|
||||
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
    """This eagerly loads the connector and credential so that the db_session can be expired
    before running long-living indexing jobs, which causes increasing memory usage"""
    query = (
        select(IndexAttempt)
        .where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
        .options(
            joinedload(IndexAttempt.connector),
            joinedload(IndexAttempt.credential),
        )
    )
    return list(db_session.scalars(query).all())
|
||||
|
||||
@@ -88,6 +95,9 @@ def mark_attempt_failed(
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
|
||||
source = index_attempt.connector.source
|
||||
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
|
||||
|
||||
|
||||
def update_docs_indexed(
|
||||
db_session: Session,
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from typing import TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
@@ -13,14 +14,15 @@ from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTa
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
@@ -31,9 +33,9 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.search.models import RecencyBiasSetting
|
||||
from danswer.search.models import SearchType
|
||||
|
||||
|
||||
@@ -64,6 +66,11 @@ class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
Auth/Authz (users, permissions, access) Tables
|
||||
"""
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
# even an almost empty token from keycloak will not fit the default 1024 bytes
|
||||
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
|
||||
@@ -79,12 +86,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
credentials: Mapped[List["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
)
|
||||
query_events: Mapped[List["QueryEvent"]] = relationship(
|
||||
"QueryEvent", back_populates="user"
|
||||
)
|
||||
chat_sessions: Mapped[List["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
)
|
||||
prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user")
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
@@ -92,7 +98,7 @@ class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
|
||||
|
||||
"""
|
||||
Association tables
|
||||
Association Tables
|
||||
NOTE: must be at the top since they are referenced by other tables
|
||||
"""
|
||||
|
||||
@@ -106,6 +112,13 @@ class Persona__DocumentSet(Base):
|
||||
)
|
||||
|
||||
|
||||
class Persona__Prompt(Base):
|
||||
__tablename__ = "persona__prompt"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
prompt_id: Mapped[int] = mapped_column(ForeignKey("prompt.id"), primary_key=True)
|
||||
|
||||
|
||||
class DocumentSet__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "document_set__connector_credential_pair"
|
||||
|
||||
@@ -130,6 +143,31 @@ class DocumentSet__ConnectorCredentialPair(Base):
|
||||
document_set: Mapped["DocumentSet"] = relationship("DocumentSet")
|
||||
|
||||
|
||||
class ChatMessage__SearchDoc(Base):
|
||||
__tablename__ = "chat_message__search_doc"
|
||||
|
||||
chat_message_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id"), primary_key=True
|
||||
)
|
||||
search_doc_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("search_doc.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class Document__Tag(Base):
|
||||
__tablename__ = "document__tag"
|
||||
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("document.id"), primary_key=True
|
||||
)
|
||||
tag_id: Mapped[int] = mapped_column(ForeignKey("tag.id"), primary_key=True)
|
||||
|
||||
|
||||
"""
|
||||
Documents/Indexing Tables
|
||||
"""
|
||||
|
||||
|
||||
class ConnectorCredentialPair(Base):
|
||||
"""Connectors and Credentials can have a many-to-many relationship
|
||||
I.e. A Confluence Connector may have multiple admin users who can run it with their own credentials
|
||||
@@ -145,9 +183,7 @@ class ConnectorCredentialPair(Base):
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(
|
||||
String, unique=True, nullable=True
|
||||
) # nullable for backwards compatability
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
connector_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector.id"), primary_key=True
|
||||
)
|
||||
@@ -185,6 +221,70 @@ class ConnectorCredentialPair(Base):
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Danswer)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
from_ingestion_api: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=True
|
||||
)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
# connector retries)
|
||||
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# The following are not attached to User because the account/email may not be known
|
||||
# within Danswer
|
||||
# Something like the document creator
|
||||
primary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# Something like assignee or space owner
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# TODO if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
retrieval_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
secondary="document__tag",
|
||||
back_populates="documents",
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
tag_key: Mapped[str] = mapped_column(String)
|
||||
tag_value: Mapped[str] = mapped_column(String)
|
||||
source: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
|
||||
|
||||
documents = relationship(
|
||||
"Document",
|
||||
secondary="document__tag",
|
||||
back_populates="tags",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Connector(Base):
|
||||
__tablename__ = "connector"
|
||||
|
||||
@@ -226,7 +326,7 @@ class Credential(Base):
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# if `true`, then all Admins will have access to the credential
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
@@ -315,8 +415,7 @@ class IndexAttempt(Base):
|
||||
|
||||
|
||||
class DocumentByConnectorCredentialPair(Base):
|
||||
"""Represents an indexing of a document by a specific connector / credential
|
||||
pair"""
|
||||
"""Represents an indexing of a document by a specific connector / credential pair"""
|
||||
|
||||
__tablename__ = "document_by_connector_credential_pair"
|
||||
|
||||
@@ -337,47 +436,136 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
)
|
||||
|
||||
|
||||
class QueryEvent(Base):
|
||||
__tablename__ = "query_event"
|
||||
"""
|
||||
Messages Tables
|
||||
"""
|
||||
|
||||
|
||||
class SearchDoc(Base):
|
||||
"""Different from Document table. This one stores the state of a document from a retrieval.
|
||||
This allows chat sessions to be replayed with the searched docs
|
||||
|
||||
Notably, this does not include the contents of the Document/Chunk, during inference if a stored
|
||||
SearchDoc is selected, an inference must be remade to retrieve the contents
|
||||
"""
|
||||
|
||||
__tablename__ = "search_doc"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
query: Mapped[str] = mapped_column(Text)
|
||||
# search_flow refers to user selection, None if user used auto
|
||||
selected_search_flow: Mapped[SearchType | None] = mapped_column(
|
||||
Enum(SearchType), nullable=True
|
||||
document_id: Mapped[str] = mapped_column(String)
|
||||
chunk_ind: Mapped[int] = mapped_column(Integer)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
blurb: Mapped[str] = mapped_column(String)
|
||||
boost: Mapped[int] = mapped_column(Integer)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
|
||||
hidden: Mapped[bool] = mapped_column(Boolean)
|
||||
doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB())
|
||||
score: Mapped[float] = mapped_column(Float)
|
||||
match_highlights: Mapped[list[str]] = mapped_column(postgresql.ARRAY(String))
|
||||
# This is for the document, not this row in the table
|
||||
updated_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
llm_answer: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
# Document IDs of the top context documents retrieved for the query (if any)
|
||||
# NOTE: not using a foreign key to enable easy deletion of documents without
|
||||
# needing to adjust `QueryEvent` rows
|
||||
retrieved_document_ids: Mapped[list[str] | None] = mapped_column(
|
||||
primary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
feedback: Mapped[QAFeedbackType | None] = mapped_column(
|
||||
Enum(QAFeedbackType), nullable=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="query_events")
|
||||
document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="qa_event"
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="search_docs",
|
||||
)
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Only ever set to True if system is set to not hard-delete chats
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
messages: Mapped[List["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="delete"
|
||||
)
|
||||
persona: Mapped["Persona"] = relationship("Persona")
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
"""Note, the first message in a chain has no contents, it's a workaround to allow edits
|
||||
on the first message of a session, an empty root node basically
|
||||
|
||||
Since every user message is followed by a LLM response, chat messages generally come in pairs.
|
||||
Keeping them as separate messages however for future Agentification extensions
|
||||
Fields will be largely duplicated in the pair.
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_message"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
|
||||
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
rephrased_query: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# If None, then there is no answer generation, it's the special case of only
|
||||
# showing the user the retrieved docs
|
||||
prompt_id: Mapped[int | None] = mapped_column(ForeignKey("prompt.id"))
|
||||
# If prompt is None, then token_count is 0 as this message won't be passed into
|
||||
# the LLM's context (not included in the history of messages)
|
||||
token_count: Mapped[int] = mapped_column(Integer)
|
||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
||||
# Maps the citation numbers to a SearchDoc id
|
||||
citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
# Only applies for LLM
|
||||
error: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
time_sent: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
|
||||
chat_message_feedbacks: Mapped[List["ChatMessageFeedback"]] = relationship(
|
||||
"ChatMessageFeedback", back_populates="chat_message"
|
||||
)
|
||||
document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="chat_message"
|
||||
)
|
||||
search_docs = relationship(
|
||||
"SearchDoc",
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Feedback, Logging, Metrics Tables
|
||||
"""
|
||||
|
||||
|
||||
class DocumentRetrievalFeedback(Base):
|
||||
__tablename__ = "document_retrieval_feedback"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
qa_event_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("query_event.id"),
|
||||
)
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("document.id"),
|
||||
)
|
||||
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
|
||||
# How high up this document is in the results, 1 for first
|
||||
document_rank: Mapped[int] = mapped_column(Integer)
|
||||
clicked: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -385,46 +573,32 @@ class DocumentRetrievalFeedback(Base):
|
||||
Enum(SearchFeedbackType), nullable=True
|
||||
)
|
||||
|
||||
qa_event: Mapped[QueryEvent] = relationship(
|
||||
"QueryEvent", back_populates="document_feedbacks"
|
||||
chat_message: Mapped[ChatMessage] = relationship(
|
||||
"ChatMessage", back_populates="document_feedbacks"
|
||||
)
|
||||
document: Mapped["Document"] = relationship(
|
||||
document: Mapped[Document] = relationship(
|
||||
"Document", back_populates="retrieval_feedbacks"
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
class ChatMessageFeedback(Base):
|
||||
__tablename__ = "chat_feedback"
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Danswer)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# The following are not attached to User because the account/email may not be known
|
||||
# within Danswer
|
||||
# Something like the document creator
|
||||
primary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# Something like assignee or space owner
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# TODO if more sensitive data is added here for display, make sure to add user/group permission
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
chat_message: Mapped[ChatMessage] = relationship(
|
||||
"ChatMessage", back_populates="chat_message_feedbacks"
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Structures, Organizational, Configurations Tables
|
||||
"""
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
__tablename__ = "document_set"
|
||||
|
||||
@@ -432,7 +606,7 @@ class DocumentSet(Base):
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# whether or not changes to the document set have been propogated
|
||||
# Whether changes to the document set have been propagated
|
||||
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
|
||||
@@ -448,59 +622,84 @@ class DocumentSet(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# If not belong to a user, then it's shared
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
system_prompt: Mapped[str] = mapped_column(Text)
|
||||
task_prompt: Mapped[str] = mapped_column(Text)
|
||||
include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# Default prompts are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_prompt: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# The following texts help build up the model's ability to use the context effectively
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
messages: Mapped[List["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="delete"
|
||||
user: Mapped[User] = relationship("User", back_populates="prompts")
|
||||
personas: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__Prompt.__table__,
|
||||
back_populates="prompts",
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
# TODO introduce user and group ownership for personas
|
||||
__tablename__ = "persona"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# If not belong to a user, then it's shared
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
# Danswer retrieval, treated as a special tool
|
||||
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tools: Mapped[list[ToolInfo] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Currently stored but unused, all flows use hybrid
|
||||
search_type: Mapped[SearchType] = mapped_column(
|
||||
Enum(SearchType), default=SearchType.HYBRID
|
||||
)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
# If unspecified, uses the default DEFAULT_NUM_CHUNKS_FED_TO_CHAT set in the env variable
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(Enum(RecencyBiasSetting))
|
||||
# Allows the Persona to specify a different LLM version than is controlled
|
||||
# globablly via env variables. For flexibility, validity is not currently enforced
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
# auto-detected time filters, relevance filters, etc.
|
||||
llm_model_version_override: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
)
|
||||
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# Default personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# If it's updated and no longer latest (should no longer be shown), it is also considered deleted
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
# higher priority personas are displayed first, ties are resolved by the ID,
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# These are only defaults, users can select from all if desired
|
||||
prompts: Mapped[list[Prompt]] = relationship(
|
||||
"Prompt",
|
||||
secondary=Persona__Prompt.__table__,
|
||||
back_populates="personas",
|
||||
)
|
||||
# These are only defaults, users can select from all if desired
|
||||
document_sets: Mapped[list[DocumentSet]] = relationship(
|
||||
"DocumentSet",
|
||||
secondary=Persona__DocumentSet.__table__,
|
||||
back_populates="personas",
|
||||
)
|
||||
user: Mapped[User] = relationship("User", back_populates="personas")
|
||||
|
||||
# Default personas loaded via yaml cannot have the same name
|
||||
__table_args__ = (
|
||||
@@ -513,78 +712,13 @@ class Persona(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_message"
|
||||
|
||||
chat_session_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_session.id"), primary_key=True
|
||||
)
|
||||
message_number: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
edit_number: Mapped[int] = mapped_column(Integer, default=0, primary_key=True)
|
||||
parent_edit_number: Mapped[int | None] = mapped_column(
|
||||
Integer, nullable=True
|
||||
) # null if first message
|
||||
latest: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
token_count: Mapped[int] = mapped_column(Integer)
|
||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
||||
reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
time_sent: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
|
||||
|
||||
class ChatMessageFeedback(Base):
|
||||
__tablename__ = "chat_feedback"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
chat_message_chat_session_id: Mapped[int] = mapped_column(Integer)
|
||||
chat_message_message_number: Mapped[int] = mapped_column(Integer)
|
||||
chat_message_edit_number: Mapped[int] = mapped_column(Integer)
|
||||
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
[
|
||||
"chat_message_chat_session_id",
|
||||
"chat_message_message_number",
|
||||
"chat_message_edit_number",
|
||||
],
|
||||
[
|
||||
"chat_message.chat_session_id",
|
||||
"chat_message.message_number",
|
||||
"chat_message.edit_number",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
chat_message: Mapped[ChatMessage] = relationship(
|
||||
"ChatMessage",
|
||||
foreign_keys=[
|
||||
chat_message_chat_session_id,
|
||||
chat_message_message_number,
|
||||
chat_message_edit_number,
|
||||
],
|
||||
backref="feedbacks",
|
||||
)
|
||||
|
||||
|
||||
AllowedAnswerFilters = (
|
||||
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
|
||||
)
|
||||
|
||||
|
||||
class ChannelConfig(TypedDict):
|
||||
"""NOTE: is a `TypedDict` so it can be used a type hint for a JSONB column
|
||||
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
|
||||
in Postgres"""
|
||||
|
||||
channel_names: list[str]
|
||||
|
||||
@@ -3,15 +3,19 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.search.models import RecencyBiasSetting
|
||||
|
||||
|
||||
def _build_persona_name(channel_names: list[str]) -> str:
|
||||
return f"__slack_bot_persona__{'-'.join(channel_names)}"
|
||||
return f"{SLACK_BOT_PERSONA_PREFIX}{'-'.join(channel_names)}"
|
||||
|
||||
|
||||
def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
|
||||
@@ -26,55 +30,51 @@ def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
|
||||
db_session.delete(rel)
|
||||
|
||||
|
||||
def _create_slack_bot_persona(
|
||||
def create_slack_bot_persona(
|
||||
db_session: Session,
|
||||
channel_names: list[str],
|
||||
document_sets: list[int],
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
# create/update persona associated with the slack bot
|
||||
persona_name = _build_persona_name(channel_names)
|
||||
persona = upsert_persona(
|
||||
user_id=None, # Slack Bot Personas are not attached to users
|
||||
persona_id=existing_persona_id,
|
||||
name=persona_name,
|
||||
datetime_aware=False,
|
||||
retrieval_enabled=True,
|
||||
system_text=None,
|
||||
tools=None,
|
||||
hint_text=None,
|
||||
description="",
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompts=None,
|
||||
document_sets=document_sets,
|
||||
llm_model_version_override=None,
|
||||
shared=True,
|
||||
default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
if existing_persona_id:
|
||||
_cleanup_relationships(db_session=db_session, persona_id=existing_persona_id)
|
||||
|
||||
# create relationship between the new persona and the desired document_sets
|
||||
for document_set_id in document_sets:
|
||||
db_session.add(
|
||||
Persona__DocumentSet(persona_id=persona.id, document_set_id=document_set_id)
|
||||
)
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def insert_slack_bot_config(
|
||||
document_sets: list[int],
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
persona = None
|
||||
if document_sets:
|
||||
persona = _create_slack_bot_persona(
|
||||
db_session=db_session,
|
||||
channel_names=channel_config["channel_names"],
|
||||
document_sets=document_sets,
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona.id if persona else None,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
)
|
||||
db_session.add(slack_bot_config)
|
||||
@@ -85,7 +85,7 @@ def insert_slack_bot_config(
|
||||
|
||||
def update_slack_bot_config(
|
||||
slack_bot_config_id: int,
|
||||
document_sets: list[int],
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
@@ -96,31 +96,29 @@ def update_slack_bot_config(
|
||||
raise ValueError(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
# get the existing persona id before updating the object
|
||||
existing_persona_id = slack_bot_config.persona_id
|
||||
|
||||
persona = None
|
||||
if document_sets:
|
||||
persona = _create_slack_bot_persona(
|
||||
db_session=db_session,
|
||||
channel_names=channel_config["channel_names"],
|
||||
document_sets=document_sets,
|
||||
existing_persona_id=slack_bot_config.persona_id,
|
||||
# update the config
|
||||
# NOTE: need to do this before cleaning up the old persona or else we
|
||||
# will encounter `violates foreign key constraint` errors
|
||||
slack_bot_config.persona_id = persona_id
|
||||
slack_bot_config.channel_config = channel_config
|
||||
|
||||
# if the persona has changed, then clean up the old persona
|
||||
if persona_id != existing_persona_id and existing_persona_id:
|
||||
existing_persona = db_session.scalar(
|
||||
select(Persona).where(Persona.id == existing_persona_id)
|
||||
)
|
||||
else:
|
||||
# if no document sets and an existing persona exists, then
|
||||
# remove persona + persona -> document set relationships
|
||||
if existing_persona_id:
|
||||
# if the existing persona was one created just for use with this Slack Bot,
|
||||
# then clean it up
|
||||
if existing_persona and existing_persona.name.startswith(
|
||||
SLACK_BOT_PERSONA_PREFIX
|
||||
):
|
||||
_cleanup_relationships(
|
||||
db_session=db_session, persona_id=existing_persona_id
|
||||
)
|
||||
existing_persona = db_session.scalar(
|
||||
select(Persona).where(Persona.id == existing_persona_id)
|
||||
)
|
||||
db_session.delete(existing_persona)
|
||||
|
||||
slack_bot_config.persona_id = persona.id if persona else None
|
||||
slack_bot_config.channel_config = channel_config
|
||||
db_session.commit()
|
||||
|
||||
return slack_bot_config
|
||||
@@ -140,11 +138,30 @@ def remove_slack_bot_config(
|
||||
|
||||
existing_persona_id = slack_bot_config.persona_id
|
||||
if existing_persona_id:
|
||||
_cleanup_relationships(db_session=db_session, persona_id=existing_persona_id)
|
||||
existing_persona = db_session.scalar(
|
||||
select(Persona).where(Persona.id == existing_persona_id)
|
||||
)
|
||||
# if the existing persona was one created just for use with this Slack Bot,
|
||||
# then clean it up
|
||||
if existing_persona and existing_persona.name.startswith(
|
||||
SLACK_BOT_PERSONA_PREFIX
|
||||
):
|
||||
_cleanup_relationships(
|
||||
db_session=db_session, persona_id=existing_persona_id
|
||||
)
|
||||
db_session.delete(existing_persona)
|
||||
|
||||
db_session.delete(slack_bot_config)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_slack_bot_config(
|
||||
db_session: Session, slack_bot_config_id: int
|
||||
) -> SlackBotConfig | None:
|
||||
return db_session.scalar(
|
||||
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_slack_bot_configs(db_session: Session) -> Sequence[SlackBotConfig]:
|
||||
return db_session.scalars(select(SlackBotConfig)).all()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user