mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00
Compare commits
166 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf6efbe8e8 | ||
|
|
7c29b1e028 | ||
|
|
a52c821e78 | ||
|
|
21967d4b6f | ||
|
|
cae8a131a2 | ||
|
|
72b4e8e9fe | ||
|
|
c04e2f14d9 | ||
|
|
b40a12d5d7 | ||
|
|
5e7d454ebe | ||
|
|
238509c536 | ||
|
|
9455576078 | ||
|
|
71421bb782 | ||
|
|
b88cb388b7 | ||
|
|
639986001f | ||
|
|
e7a7e78969 | ||
|
|
e255ff7d23 | ||
|
|
1be2502112 | ||
|
|
f2bedb8fdd | ||
|
|
637404f482 | ||
|
|
daae146920 | ||
|
|
d95959fb41 | ||
|
|
c667d28e7a | ||
|
|
9e0b482f47 | ||
|
|
fa84eb657f | ||
|
|
264df3441b | ||
|
|
b9bad8b7a0 | ||
|
|
600ebb6432 | ||
|
|
09fe8ea868 | ||
|
|
ad6be03b4d | ||
|
|
65d2511216 | ||
|
|
113bf19c65 | ||
|
|
6026536110 | ||
|
|
056b671cd4 | ||
|
|
8d83ae2ee8 | ||
|
|
ca988f5c5f | ||
|
|
4e4214b82c | ||
|
|
fe83f676df | ||
|
|
6d6e12119b | ||
|
|
1f2b7cb9c8 | ||
|
|
878a189011 | ||
|
|
48c10271c2 | ||
|
|
c6a79d847e | ||
|
|
1bc3f8b96f | ||
|
|
7f6a6944d6 | ||
|
|
06f4146597 | ||
|
|
7ea73d5a5a | ||
|
|
30dfe6dcb4 | ||
|
|
dc5d5dfe05 | ||
|
|
0746e0be5b | ||
|
|
970320bd49 | ||
|
|
4a7bd5578e | ||
|
|
874b098a4b | ||
|
|
ce18b63eea | ||
|
|
7a919c3589 | ||
|
|
631bac4432 | ||
|
|
53428f6e9c | ||
|
|
53b3dcbace | ||
|
|
7a3c06c2d2 | ||
|
|
7a0d823c89 | ||
|
|
db69e445d6 | ||
|
|
18e63889b7 | ||
|
|
738e60c8ed | ||
|
|
8aec873e66 | ||
|
|
7c57dde8ab | ||
|
|
f30adab853 | ||
|
|
601687a522 | ||
|
|
350cf407c9 | ||
|
|
32ec4efc7a | ||
|
|
7c6981e052 | ||
|
|
c50cd20156 | ||
|
|
14772dee71 | ||
|
|
c81e704c95 | ||
|
|
3266ef6321 | ||
|
|
c89b98b4f2 | ||
|
|
e70e0ab859 | ||
|
|
69b6e9321e | ||
|
|
7e53af18b6 | ||
|
|
b9eb1ca2ba | ||
|
|
91d44c83d2 | ||
|
|
4dbc6bb4d1 | ||
|
|
4b6a4c6bbf | ||
|
|
fd1999454a | ||
|
|
0a35422d1d | ||
|
|
69b99056b2 | ||
|
|
2a55696545 | ||
|
|
ef9942b751 | ||
|
|
993acec5e9 | ||
|
|
b01a1b509a | ||
|
|
4f994124ef | ||
|
|
14863bd457 | ||
|
|
aa1c4c635a | ||
|
|
13f6e8a6b4 | ||
|
|
66f47d294c | ||
|
|
0a685bda7d | ||
|
|
23dc8b5dad | ||
|
|
cd5f2293ad | ||
|
|
6c2269e565 | ||
|
|
46315cddf1 | ||
|
|
5f28a1b0e4 | ||
|
|
9e9b7ed61d | ||
|
|
3fb2bfefec | ||
|
|
7c618c9d17 | ||
|
|
03e2789392 | ||
|
|
2783fa08a3 | ||
|
|
edeaee93a2 | ||
|
|
5385bae100 | ||
|
|
813445ab59 | ||
|
|
af814823c8 | ||
|
|
607f61eaeb | ||
|
|
de66f7adb2 | ||
|
|
3432d932d1 | ||
|
|
9bd0cb9eb5 | ||
|
|
f12eb4a5cf | ||
|
|
16863de0aa | ||
|
|
63d1eefee5 | ||
|
|
e338677896 | ||
|
|
7be80c4af9 | ||
|
|
7f1e4a02bf | ||
|
|
5be7d27285 | ||
|
|
fd84b7a768 | ||
|
|
36941ae663 | ||
|
|
212353ed4a | ||
|
|
eb8708f770 | ||
|
|
ac448956e9 | ||
|
|
634a0b9398 | ||
|
|
09d3e47c03 | ||
|
|
9c0cc94f15 | ||
|
|
07dfde2209 | ||
|
|
28e2b78b2e | ||
|
|
0553062ac6 | ||
|
|
284e375ba3 | ||
|
|
1f2f7d0ac2 | ||
|
|
2ecc28b57d | ||
|
|
77cf9b3539 | ||
|
|
076ce2ebd0 | ||
|
|
b625ee32a7 | ||
|
|
c32b93fcc3 | ||
|
|
1c8476072e | ||
|
|
7573416ca1 | ||
|
|
86d8666481 | ||
|
|
8abcde91d4 | ||
|
|
3466451d51 | ||
|
|
413891f143 | ||
|
|
7a0a4d4b79 | ||
|
|
a3439605a5 | ||
|
|
694e79f5e1 | ||
|
|
5dfafc8612 | ||
|
|
62a4aa10db | ||
|
|
a357cdc4c9 | ||
|
|
84615abfdd | ||
|
|
8ae6b1960b | ||
|
|
d9b87bbbc2 | ||
|
|
a0065b01af | ||
|
|
c5306148a3 | ||
|
|
1e17934de4 | ||
|
|
93add96ccc | ||
|
|
3a466a4b08 | ||
|
|
85cbd9caed | ||
|
|
9dc23bf3e7 | ||
|
|
e32809f7ca | ||
|
|
3e58f9f8ab | ||
|
|
2381c8d498 | ||
|
|
c6dadb24dc | ||
|
|
5dc07d4178 | ||
|
|
129c8f8faf | ||
|
|
67bfcabbc5 |
@@ -24,6 +24,8 @@ env:
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -32,7 +32,7 @@ To contribute to this project, please follow the
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
|
||||
### Getting Help 🙋
|
||||
|
||||
60
README.md
60
README.md
@@ -1,48 +1,48 @@
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
|
||||
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
|
||||
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://docs.danswer.dev/" target="_blank">
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2sslpdbyq-iIbTaNIVPBw_i_4vrujLYQ" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
|
||||
</a>
|
||||
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
|
||||
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
|
||||
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring Personas (AI Assistants) and their Prompts.
|
||||
configuring AI Assistants.
|
||||
|
||||
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Usage</h3>
|
||||
|
||||
Danswer Web App:
|
||||
Onyx Web App:
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
|
||||
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
@@ -52,16 +52,16 @@ For more details on the Admin UI to manage connectors and users, check out our
|
||||
|
||||
## Deployment
|
||||
|
||||
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
|
||||
|
||||
## 💃 Main Features
|
||||
* Chat UI with the ability to select documents to chat with.
|
||||
* Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
* Document Search + AI Answers for natural language queries.
|
||||
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
* Slack integration to get answers and search results directly in Slack.
|
||||
@@ -75,12 +75,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
|
||||
* Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
|
||||
## Other Notable Benefits of Danswer
|
||||
## Other Notable Benefits of Onyx
|
||||
* User Authentication with document level access management.
|
||||
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
* Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
* Custom deep learning models + learn from user feedback.
|
||||
* Easy deployment and ability to host Danswer anywhere of your choosing.
|
||||
* Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
|
||||
|
||||
## 🔌 Connectors
|
||||
@@ -108,10 +108,10 @@ Efficiently pulls the latest changes from:
|
||||
|
||||
## 📚 Editions
|
||||
|
||||
There are two editions of Danswer:
|
||||
There are two editions of Onyx:
|
||||
|
||||
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
|
||||
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
* Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
* Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
* Role-based access control
|
||||
* Document permission inheritance from connected sources
|
||||
@@ -119,24 +119,24 @@ There are two editions of Danswer:
|
||||
* Whitelabeling
|
||||
* API key authentication
|
||||
* Encryption of secrets
|
||||
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
|
||||
* Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Danswer Enterprise Edition:
|
||||
To try the Onyx Enterprise Edition:
|
||||
|
||||
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
|
||||
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#danswer-ai/danswer&Date)
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
|
||||
@@ -73,6 +73,7 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,6 +8,7 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -35,7 +36,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
45
backend/alembic/versions/6d562f86c78b_remove_default_bot.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""remove default bot
|
||||
|
||||
Revision ID: 6d562f86c78b
|
||||
Revises: 177de57c21c9
|
||||
Create Date: 2024-11-22 11:51:29.331336
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6d562f86c78b"
|
||||
down_revision = "177de57c21c9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM slack_bot
|
||||
WHERE name = 'Default Bot'
|
||||
AND bot_token = ''
|
||||
AND app_token = ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM slack_channel_config
|
||||
WHERE slack_channel_config.slack_bot_id = slack_bot.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
|
||||
SELECT 'Default Bot', true, '', ''
|
||||
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
|
||||
RETURNING id;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -9,8 +9,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add web ui option to slack config
|
||||
|
||||
Revision ID: 93560ba1b118
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-24 06:36:17.490612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93560ba1b118"
|
||||
down_revision = "6d562f86c78b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add show_continue_in_web_ui with default False to all existing channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
|
||||
WHERE NOT channel_config ? 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove show_continue_in_web_ui from all channel_configs
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE slack_channel_config
|
||||
SET channel_config = channel_config - 'show_continue_in_web_ui'
|
||||
"""
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add indexing trigger to cc_pair
|
||||
|
||||
Revision ID: abe7378b8217
|
||||
Revises: 6d562f86c78b
|
||||
Create Date: 2024-11-26 19:09:53.481171
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abe7378b8217"
|
||||
down_revision = "93560ba1b118"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"indexing_trigger",
|
||||
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "indexing_trigger")
|
||||
@@ -0,0 +1,57 @@
|
||||
"""delete_input_prompts
|
||||
|
||||
Revision ID: bf7a81109301
|
||||
Revises: f7a894b06d02
|
||||
Create Date: 2024-12-09 12:00:49.884228
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bf7a81109301"
|
||||
down_revision = "f7a894b06d02"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""non-nullbale slack bot id in channel config
|
||||
|
||||
Revision ID: f7a894b06d02
|
||||
Revises: 9f696734098f
|
||||
Create Date: 2024-12-06 12:55:42.845723
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7a894b06d02"
|
||||
down_revision = "9f696734098f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete all rows with null slack_bot_id
|
||||
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
|
||||
|
||||
# Make slack_bot_id non-nullable
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Make slack_bot_id nullable again
|
||||
op.alter_column(
|
||||
"slack_channel_config",
|
||||
"slack_bot_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
)
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str,
|
||||
type_: str,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -18,6 +18,11 @@ class ExternalAccess:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from urllib.parse import quote
|
||||
@@ -18,7 +19,8 @@ _API_KEY_HEADER_NAME = "Authorization"
|
||||
# organizations like the Internet Engineering Task Force (IETF).
|
||||
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
|
||||
_BEARER_PREFIX = "Bearer "
|
||||
_API_KEY_PREFIX = "dn_"
|
||||
_API_KEY_PREFIX = "on_"
|
||||
_DEPRECATED_API_KEY_PREFIX = "dn_"
|
||||
_API_KEY_LEN = 192
|
||||
|
||||
|
||||
@@ -52,7 +54,9 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
|
||||
|
||||
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
|
||||
|
||||
if not api_key.startswith(_API_KEY_PREFIX):
|
||||
if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
|
||||
_DEPRECATED_API_KEY_PREFIX
|
||||
):
|
||||
return None
|
||||
|
||||
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
|
||||
@@ -63,10 +67,19 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
|
||||
return unquote(tenant_id) if tenant_id else None
|
||||
|
||||
|
||||
def _deprecated_hash_api_key(api_key: str) -> str:
|
||||
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
# NOTE: no salt is needed, as the API key is randomly generated
|
||||
# and overlaps are impossible
|
||||
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
|
||||
if api_key.startswith(_API_KEY_PREFIX):
|
||||
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
|
||||
elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
|
||||
return _deprecated_hash_api_key(api_key)
|
||||
else:
|
||||
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
|
||||
|
||||
|
||||
def build_displayable_api_key(api_key: str) -> str:
|
||||
|
||||
@@ -9,7 +9,6 @@ from danswer.utils.special_types import JSON_ro
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
return UserPreferences(
|
||||
chosen_assistants=None, default_model=None, auto_scroll=True
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
|
||||
@@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
@@ -58,7 +58,6 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import DISABLE_VERIFICATION
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
@@ -80,13 +79,14 @@ from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -99,11 +99,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -136,11 +131,12 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
# all other auth types besides basic should require users to be
|
||||
# verified
|
||||
return not DISABLE_VERIFICATION and (
|
||||
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
return REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
# For other auth types, if the user is authenticated it's assumed that
|
||||
# the user is already verified via the external IDP
|
||||
return False
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
@@ -609,7 +605,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
async_db_session: AsyncSession,
|
||||
) -> User | None:
|
||||
"""NOTE: `request` and `db_session` are not used here, but are included
|
||||
for the EE version of this function."""
|
||||
@@ -618,13 +614,21 @@ async def optional_user_(
|
||||
|
||||
async def optional_user(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
user = await versioned_fetch_user(request, user, async_db_session)
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
@@ -910,8 +914,8 @@ def get_oauth_router(
|
||||
return router
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
async def api_key_dep(
|
||||
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
@@ -921,7 +925,7 @@ def api_key_dep(
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -332,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock = sender.primary_worker_lock
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -24,7 +25,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
@@ -38,7 +39,6 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -116,9 +116,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
lock = r.lock(
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
@@ -165,13 +169,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
continue
|
||||
|
||||
failure_reason = (
|
||||
f"Orphaned index attempt found on startup: "
|
||||
f"Canceling leftover index attempt found on startup: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
mark_attempt_failed(attempt.id, db_session, failure_reason)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
@@ -227,7 +231,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock = worker.primary_worker_lock
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,54 +2,55 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": "check_for_indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": "check_for_pruning",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": "monitor_vespa_sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": "check_for_external_group_sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
|
||||
@@ -5,13 +5,13 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_connector_deletion_task",
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError):
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
self.app, cc_pair_id, db_session, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
@@ -86,7 +86,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -17,9 +18,11 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -27,7 +30,7 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
@@ -81,7 +84,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_doc_permissions_sync",
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -138,7 +141,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -162,8 +165,8 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
app.send_task(
|
||||
"connector_permission_sync_generator_task",
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -174,8 +177,8 @@ def try_creating_permissions_sync_task(
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=None,
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
@@ -190,7 +193,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_permission_sync_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -216,7 +219,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
@@ -241,13 +244,17 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
)
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
@@ -256,7 +263,12 @@ def connector_permission_sync_generator_task(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -281,7 +293,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="update_external_document_permissions_task",
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
@@ -292,6 +304,8 @@ def update_external_document_permissions_task(
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
@@ -300,18 +314,28 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Then we build the update requests to update vespa
|
||||
# Add the users to the DB if they don't exist
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
upsert_document_external_perms(
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -16,6 +17,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -24,13 +26,20 @@ from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -49,7 +58,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
@@ -81,7 +90,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_external_group_sync",
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -102,12 +111,28 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
@@ -125,7 +150,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
@@ -156,8 +181,8 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
_ = app.send_task(
|
||||
"connector_external_group_sync_generator_task",
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -166,8 +191,13 @@ def try_creating_permissions_sync_task(
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
# set a basic fence to start
|
||||
redis_connector.external_group_sync.set_fence(True)
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
@@ -182,7 +212,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_external_group_sync_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -195,7 +225,7 @@ def connector_external_group_sync_generator_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
@@ -203,7 +233,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
@@ -228,9 +258,13 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(f"No external group sync func found for {source_type}")
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
@@ -249,7 +283,6 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
@@ -260,6 +293,6 @@ def connector_external_group_sync_generator_task(
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(False)
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@@ -23,13 +23,16 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
@@ -37,12 +40,13 @@ from danswer.db.index_attempt import delete_index_attempt
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_active_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
@@ -77,7 +81,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = ""
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
@@ -153,13 +157,13 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_indexing",
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
|
||||
locked = False
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -172,6 +176,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -205,17 +211,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
@@ -231,22 +230,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing manual trigger detected: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"indexing_mode={cc_pair.indexing_trigger}"
|
||||
)
|
||||
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
False,
|
||||
reindex,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
@@ -256,7 +279,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"Connector indexing queued: "
|
||||
f"index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
@@ -281,7 +304,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -289,13 +311,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
return tasks_created
|
||||
|
||||
@@ -304,6 +327,7 @@ def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@@ -368,6 +392,11 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
@@ -458,7 +487,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
"connector_indexing_proxy_task",
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -495,8 +524,14 @@ def try_creating_indexing_task(
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
@@ -509,6 +544,10 @@ def connector_indexing_proxy_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
@@ -537,29 +576,80 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
job.cancel()
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
task_logger.error(
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
@@ -703,9 +793,12 @@ def connector_indexing_task(
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -13,12 +13,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="kombu_message_cleanup_task",
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -20,6 +21,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -75,7 +77,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -184,7 +186,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
"connector_pruning_generator_task",
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -209,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_pruning_generator_task",
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -238,9 +240,12 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -9,6 +9,7 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -31,7 +32,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -46,6 +47,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
@@ -58,7 +60,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
@@ -79,7 +81,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name="check_for_vespa_sync_task",
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -588,7 +590,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncData | None = (
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
@@ -596,9 +598,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
@@ -655,33 +655,52 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
result_state = result.state
|
||||
task_state = result.state
|
||||
if (
|
||||
result_state in READY_STATES
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
if index_attempt:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
@@ -692,6 +711,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
@@ -699,7 +719,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -724,7 +744,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
@@ -810,7 +830,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="vespa_metadata_sync_task",
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = celery_app
|
||||
app: Celery = celery_app
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app = fetch_versioned_implementation(
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
)
|
||||
|
||||
@@ -82,7 +82,7 @@ class SimpleJob:
|
||||
return "running"
|
||||
elif self.process.exitcode is None:
|
||||
return "cancelled"
|
||||
elif self.process.exitcode > 0:
|
||||
elif self.process.exitcode != 0:
|
||||
return "error"
|
||||
else:
|
||||
return "finished"
|
||||
@@ -123,7 +123,8 @@ class SimpleJobClient:
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug(
|
||||
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
f"No available workers to run job. "
|
||||
f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
@@ -87,6 +88,10 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
@@ -208,9 +213,7 @@ def _run_indexing(
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_run_indexing: Connector stop signal detected"
|
||||
)
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
@@ -304,26 +307,16 @@ def _run_indexing(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
mark_attempt_canceled(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -335,6 +328,37 @@ def _run_indexing(
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
|
||||
@@ -6,33 +6,27 @@ from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.chat.prompt_builder.build import default_build_system_message
|
||||
from danswer.chat.prompt_builder.build import default_build_user_message
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
from danswer.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.chat.stream_processing.utils import map_document_id_order
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
@@ -212,20 +206,28 @@ class Answer:
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], {})
|
||||
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
display_doc_order_dict=displayed_search_results_map,
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
@@ -233,6 +235,8 @@ class Answer:
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
tool_choice=(
|
||||
@@ -2,20 +2,79 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -31,9 +90,49 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -196,3 +295,71 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,60 +1,22 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
tool_handler: ToolResponseHandler,
|
||||
answer_handler: AnswerResponseHandler,
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
|
||||
@@ -1,17 +1,30 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import Prompt
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
"""This contains the minimal set information for the LLM portion including citations"""
|
||||
@@ -25,6 +38,7 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -117,20 +131,6 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -146,14 +146,20 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
@@ -165,9 +171,41 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
@@ -183,3 +221,109 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
class DocumentPruningConfig(BaseModel):
|
||||
max_chunks: int | None = None
|
||||
max_window_percentage: float | None = None
|
||||
max_tokens: int | None = None
|
||||
# different pruning behavior is expected when the
|
||||
# user manually selects documents they want to chat with
|
||||
# e.g. we don't want to truncate each document to be no more
|
||||
# than one chunk long
|
||||
is_manually_selected_docs: bool = False
|
||||
# If user specifies to include additional context Chunks for each match, then different pruning
|
||||
# is used. As many Sections as possible are included, and the last Section is truncated
|
||||
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
|
||||
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
|
||||
use_sections: bool = True
|
||||
# If using tools, then we need to consider the tool length
|
||||
tool_num_tokens: int = 0
|
||||
# If using a tool message to represent the docs, then we have to JSON serialize
|
||||
# the document content, which adds to the token count.
|
||||
using_tool_message: bool = False
|
||||
|
||||
|
||||
class ContextualPruningConfig(DocumentPruningConfig):
|
||||
num_chunk_multiple: int
|
||||
|
||||
@classmethod
|
||||
def from_doc_pruning_config(
|
||||
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
|
||||
) -> "ContextualPruningConfig":
|
||||
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
)
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
@@ -6,16 +6,24 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.answer import Answer
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationConfig
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DocumentPruningConfig
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
@@ -23,6 +31,16 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.context.search.utils import drop_llm_indices
|
||||
from danswer.context.search.utils import relevant_sections_to_indices
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -44,28 +62,13 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.utils import load_all_chat_files
|
||||
from danswer.file_store.utils import save_files_from_urls
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.file_store.utils import save_files
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
@@ -102,6 +105,7 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -113,7 +117,10 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -256,6 +263,7 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -286,6 +294,8 @@ def stream_chat_message_objects(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -293,6 +303,7 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
@@ -322,17 +333,31 @@ def stream_chat_message_objects(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -555,19 +580,34 @@ def stream_chat_message_objects(
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
@@ -587,11 +627,13 @@ def stream_chat_message_objects(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -605,6 +647,7 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
@@ -637,7 +680,8 @@ def stream_chat_message_objects(
|
||||
|
||||
reference_db_search_docs = None
|
||||
qa_docs_response = None
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
# any files to associate with the AI message e.g. dall-e generated images
|
||||
ai_message_files = []
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
|
||||
@@ -692,8 +736,14 @@ def stream_chat_message_objects(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
if img.image_data
|
||||
],
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
@@ -719,15 +769,19 @@ def stream_chat_message_objects(
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=(
|
||||
ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV
|
||||
),
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
@@ -736,6 +790,8 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(DanswerContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
@@ -775,7 +831,8 @@ def stream_chat_message_objects(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
@@ -844,3 +901,30 @@ def stream_chat_message(
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_slack(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatDanswerBotResponse:
|
||||
response = ChatDanswerBotResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
|
||||
@@ -4,20 +4,26 @@ from typing import cast
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
@@ -58,8 +64,8 @@ class AnswerPromptBuilder:
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
raw_user_text: str,
|
||||
single_message_history: str | None = None,
|
||||
raw_user_text: str | None = None,
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
|
||||
@@ -89,11 +95,7 @@ class AnswerPromptBuilder:
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
self.raw_user_message = (
|
||||
HumanMessage(content=raw_user_text)
|
||||
if raw_user_text is not None
|
||||
else user_message
|
||||
)
|
||||
self.raw_user_message = raw_user_text
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
@@ -143,3 +145,15 @@ class AnswerPromptBuilder:
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -2,11 +2,12 @@ from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
@@ -29,7 +30,6 @@ from danswer.prompts.token_counts import (
|
||||
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -1,46 +1,16 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.utils import message_to_prompt_and_imgs
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
|
||||
|
||||
def _build_weak_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
|
||||
as an option to use with weaker LLMs such as small version, low float precision, quantized,
|
||||
or distilled models. It only uses one context document and has very weak requirements of
|
||||
output format.
|
||||
"""
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
|
||||
|
||||
prompt_str = WEAK_LLM_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
)
|
||||
|
||||
if prompt.datetime_aware:
|
||||
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
|
||||
|
||||
return HumanMessage(content=prompt_str)
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
@@ -81,15 +51,9 @@ def build_quotes_user_message(
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
query, _ = message_to_prompt_and_imgs(message)
|
||||
|
||||
return prompt_builder(
|
||||
return _build_strong_llm_quotes_prompt(
|
||||
question=query,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
62
backend/danswer/chat/prompt_builder/utils.py
Normal file
62
backend/danswer/chat/prompt_builder/utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
return HumanMessage(content=content)
|
||||
|
||||
raise ValueError(f"New message type {msg.message_type} not handled")
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_danswer_msg_to_langchain(msg)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
@@ -5,20 +5,20 @@ from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ContextualPruningConfig
|
||||
from danswer.chat.models import (
|
||||
LlmDoc,
|
||||
)
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.llm.answering.models import ContextualPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -3,16 +3,14 @@ from collections.abc import Generator
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.llm_response_handler import ResponsePart
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class AnswerResponseHandler(abc.ABC):
|
||||
@@ -37,17 +35,25 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.display_doc_order_dict = display_doc_order_dict
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
display_doc_order_dict=self.display_doc_order_dict,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
# TODO remove this after citation issue is resolved
|
||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
@@ -64,28 +70,29 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
@@ -4,8 +4,8 @@ from collections.abc import Generator
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -22,12 +22,16 @@ class CitationProcessor:
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.display_doc_order_dict = (
|
||||
display_doc_order_dict # original order of docs to displayed to user
|
||||
)
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
@@ -67,9 +71,9 @@ class CitationProcessor:
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
@@ -77,13 +81,15 @@ class CitationProcessor:
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
result = "" # Initialize result here
|
||||
result = ""
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
numerical_value = int(
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
@@ -96,6 +102,18 @@ class CitationProcessor:
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
||||
displayed_citation_num = self.display_doc_order_dict[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = real_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
@@ -116,6 +134,7 @@ class CitationProcessor:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
@@ -131,29 +150,24 @@ class CitationProcessor:
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -161,7 +175,8 @@ class CitationProcessor:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -1,3 +1,4 @@
|
||||
# THIS IS NO LONGER IN USE
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
@@ -5,16 +6,15 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.prompts.constants import ANSWER_PAT
|
||||
from danswer.prompts.constants import QUOTE_PAT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import clean_up_code_blocks
|
||||
@@ -26,6 +26,20 @@ logger = setup_logger()
|
||||
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Sequence
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
|
||||
|
||||
class DocumentIdOrderMapping(BaseModel):
|
||||
@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
@@ -62,7 +62,7 @@ class ToolResponseHandler:
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
@@ -76,7 +76,7 @@ class ToolResponseHandler:
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -95,7 +95,7 @@ class ToolResponseHandler:
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.get_user_message_content(),
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
@@ -1,115 +0,0 @@
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
|
||||
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
|
||||
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
|
||||
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
|
||||
from danswer.prompts.chat_tools import TOOL_TEMPLATE
|
||||
from danswer.prompts.chat_tools import USER_INPUT
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class DanswerChatModelOut(BaseModel):
|
||||
model_raw: str
|
||||
action: str
|
||||
action_input: str
|
||||
|
||||
|
||||
def call_tool(
|
||||
model_actions: DanswerChatModelOut,
|
||||
) -> str:
|
||||
raise NotImplementedError("There are no additional tool integrations right now")
|
||||
|
||||
|
||||
def form_user_prompt_text(
|
||||
query: str,
|
||||
tool_text: str | None,
|
||||
hint_text: str | None,
|
||||
user_input_prompt: str = USER_INPUT,
|
||||
tool_less_prompt: str = TOOL_LESS_PROMPT,
|
||||
) -> str:
|
||||
user_prompt = tool_text or tool_less_prompt
|
||||
|
||||
user_prompt += user_input_prompt.format(user_input=query)
|
||||
|
||||
if hint_text:
|
||||
if user_prompt[-1] != "\n":
|
||||
user_prompt += "\n"
|
||||
user_prompt += "\nHint: " + hint_text
|
||||
|
||||
return user_prompt.strip()
|
||||
|
||||
|
||||
def form_tool_section_text(
|
||||
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
|
||||
) -> str | None:
|
||||
if not tools and not retrieval_enabled:
|
||||
return None
|
||||
|
||||
if retrieval_enabled and tools:
|
||||
tools.append(
|
||||
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
|
||||
)
|
||||
|
||||
tools_intro = []
|
||||
if tools:
|
||||
num_tools = len(tools)
|
||||
for tool in tools:
|
||||
description_formatted = tool["description"].replace("\n", " ")
|
||||
tools_intro.append(f"> {tool['name']}: {description_formatted}")
|
||||
|
||||
prefix = "Must be one of " if num_tools > 1 else "Must be "
|
||||
|
||||
tools_intro_text = "\n".join(tools_intro)
|
||||
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return template.format(
|
||||
tool_overviews=tools_intro_text, tool_names=tool_names_text
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_FOLLOWUP,
|
||||
ignore_hint: bool = False,
|
||||
) -> str:
|
||||
# If multi-line query, it likely confuses the model more than helps
|
||||
if "\n" not in query:
|
||||
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
|
||||
else:
|
||||
optional_reminder = ""
|
||||
|
||||
if not ignore_hint and hint_text:
|
||||
hint_text_spaced = f"\nHint: {hint_text}\n"
|
||||
else:
|
||||
hint_text_spaced = ""
|
||||
|
||||
return tool_followup_prompt.format(
|
||||
tool_output=tool_output,
|
||||
optional_reminder=optional_reminder,
|
||||
hint=hint_text_spaced,
|
||||
).strip()
|
||||
|
||||
|
||||
def form_tool_less_followup_text(
|
||||
tool_output: str,
|
||||
query: str,
|
||||
hint_text: str | None,
|
||||
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
|
||||
) -> str:
|
||||
hint = f"Hint: {hint_text}" if hint_text else ""
|
||||
return tool_followup_prompt.format(
|
||||
context_str=tool_output, user_query=query, hint_text=hint
|
||||
).strip()
|
||||
@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
# Necessary for cloud integration tests
|
||||
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
# information. This provides an extra layer of security on top of Postgres access controls
|
||||
# and is available in Danswer EE
|
||||
@@ -84,7 +81,14 @@ OAUTH_CLIENT_SECRET = (
|
||||
or ""
|
||||
)
|
||||
|
||||
# for future OAuth connector support
|
||||
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -118,6 +122,8 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
|
||||
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
|
||||
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
|
||||
# the number of times to try and connect to vespa on startup before giving up
|
||||
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
|
||||
|
||||
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
|
||||
|
||||
@@ -234,7 +240,7 @@ except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
if not env_value:
|
||||
@@ -308,6 +314,22 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -326,6 +348,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -389,21 +417,28 @@ LARGE_CHUNK_RATIO = 4
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(
|
||||
os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
|
||||
)
|
||||
|
||||
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
|
||||
INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
|
||||
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
|
||||
)
|
||||
|
||||
# during indexing, will log verbose memory diff stats every x batches and at the end.
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
|
||||
|
||||
# Maximum file size in a document to be indexed
|
||||
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
|
||||
MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
@@ -422,6 +457,9 @@ LOG_ALL_MODEL_INTERACTIONS = (
|
||||
LOG_DANSWER_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
LOG_INDIVIDUAL_MODEL_TOKENS = (
|
||||
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -490,10 +528,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
@@ -507,3 +541,6 @@ API_KEY_HASH_ROUNDS = (
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
@@ -17,9 +16,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
# ~3k input, half for docs, half for chat history + prompts
|
||||
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
|
||||
|
||||
# For selecting a different LLM question-answering prompt format
|
||||
# Valid values: default, cot, weak
|
||||
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
@@ -27,8 +23,6 @@ DOC_TIME_DECAY = float(
|
||||
)
|
||||
BASE_RECENCY_DECAY = 0.5
|
||||
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently this next one is not configurable via env
|
||||
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
|
||||
# Note this is not in any of the deployment configs yet
|
||||
# Currently only applies to search flow not chat
|
||||
|
||||
@@ -31,6 +31,8 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -130,6 +132,7 @@ class DocumentSource(str, Enum):
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -259,6 +262,32 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -4,11 +4,8 @@ import os
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
@@ -47,17 +44,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
# if set, will default DanswerBot to use quotes and reference documents
|
||||
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
|
||||
@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
)
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
|
||||
@@ -2,6 +2,8 @@ import json
|
||||
import os
|
||||
|
||||
|
||||
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
|
||||
|
||||
# if specified, will pass through request headers to the call to API calls made by custom tools
|
||||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
|
||||
|
||||
@@ -11,11 +11,16 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll connector:
|
||||
- Poll Connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -13,6 +15,7 @@ from danswer.connectors.confluence.utils import attachment_to_content
|
||||
from danswer.connectors.confluence.utils import build_confluence_document_id
|
||||
from danswer.connectors.confluence.utils import datetime_from_string
|
||||
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from danswer.connectors.confluence.utils import validate_attachment_filetype
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@@ -51,6 +54,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -67,6 +72,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -102,6 +108,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -202,12 +210,14 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -240,10 +250,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
@@ -263,12 +273,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
perm_sync_data = {
|
||||
"restrictions": page.get("restrictions", {}),
|
||||
"space_key": page.get("space", {}).get("key"),
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
@@ -278,7 +291,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
@@ -286,7 +299,23 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
if not attachment_restrictions:
|
||||
attachment_restrictions = page_restrictions
|
||||
|
||||
attachment_space_key = attachment.get("space", {}).get("key")
|
||||
if not attachment_space_key:
|
||||
attachment_space_key = page_space_key
|
||||
|
||||
attachment_perm_sync_data = {
|
||||
"restrictions": attachment_restrictions or {},
|
||||
"space_key": attachment_space_key,
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
@@ -294,8 +323,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
)
|
||||
)
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
@@ -134,6 +134,32 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -294,14 +320,24 @@ def _validate_connector_configuration(
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_without_retries = Confluence(
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
)
|
||||
spaces = confluence_client_without_retries.get_all_spaces(limit=1)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
@@ -332,4 +368,5 @@ def build_confluence_client(
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
|
||||
@@ -32,7 +32,11 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
email = None
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
@@ -173,19 +177,23 @@ def extract_text_from_confluence_html(
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if attachment["metadata"]["mediaType"] in [
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
return attachment["metadata"]["mediaType"] not in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]:
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
@@ -241,7 +249,7 @@ def build_confluence_document_id(
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
|
||||
"""Parse a Confluence html page to generate a list of current
|
||||
attachments in use
|
||||
|
||||
|
||||
384
backend/danswer/connectors/egnyte/connector.py
Normal file
384
backend/danswer/connectors/egnyte/connector.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import OAuthConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.file_processing.extract_file_text import detect_encoding
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.extract_file_text import get_file_ext
|
||||
from danswer.file_processing.extract_file_text import is_text_file_extension
|
||||
from danswer.file_processing.extract_file_text import is_valid_file_ext
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
|
||||
|
||||
def _process_egnyte_file(
|
||||
file_metadata: dict[str, Any],
|
||||
file_content: IO,
|
||||
base_url: str,
|
||||
folder_path: str | None = None,
|
||||
) -> Document | None:
|
||||
"""Process an Egnyte file into a Document object
|
||||
|
||||
Args:
|
||||
file_data: The file data from Egnyte API
|
||||
file_content: The raw content of the file in bytes
|
||||
base_url: The base URL for the Egnyte instance
|
||||
folder_path: Optional folder path to filter results
|
||||
"""
|
||||
# Skip if file path doesn't match folder path filter
|
||||
if folder_path and not file_metadata["path"].startswith(folder_path):
|
||||
raise ValueError(
|
||||
f"File path {file_metadata['path']} does not match folder path {folder_path}"
|
||||
)
|
||||
|
||||
file_name = file_metadata["name"]
|
||||
extension = get_file_ext(file_name)
|
||||
if not is_valid_file_ext(extension):
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
return None
|
||||
|
||||
# Extract text content based on file type
|
||||
if is_text_file_extension(file_name):
|
||||
encoding = detect_encoding(file_content)
|
||||
file_content_raw, file_metadata = read_text_file(
|
||||
file_content, encoding=encoding, ignore_danswer_metadata=False
|
||||
)
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
file=file_content,
|
||||
file_name=file_name,
|
||||
break_on_unprocessable=True,
|
||||
)
|
||||
|
||||
# Build the web URL for the file
|
||||
web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"
|
||||
|
||||
# Create document metadata
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"file_path": file_metadata["path"],
|
||||
"last_modified": file_metadata.get("last_modified", ""),
|
||||
}
|
||||
|
||||
# Add lock info if present
|
||||
if lock_info := file_metadata.get("lock_info"):
|
||||
metadata[
|
||||
"lock_owner"
|
||||
] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"
|
||||
|
||||
# Create the document owners
|
||||
primary_owner = None
|
||||
if uploaded_by := file_metadata.get("uploaded_by"):
|
||||
primary_owner = BasicExpertInfo(
|
||||
email=uploaded_by, # Using username as email since that's what we have
|
||||
)
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=f"egnyte-{file_metadata['entry_id']}",
|
||||
sections=[Section(text=file_content_raw.strip(), link=web_url)],
|
||||
source=DocumentSource.EGNYTE,
|
||||
semantic_identifier=file_name,
|
||||
metadata=metadata,
|
||||
doc_updated_at=(
|
||||
_parse_last_modified(file_metadata["last_modified"])
|
||||
if "last_modified" in file_metadata
|
||||
else None
|
||||
),
|
||||
primary_owners=[primary_owner] if primary_owner else None,
|
||||
)
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.domain = "" # will always be set in `load_credentials`
|
||||
self.folder_path = folder_path or "" # Root folder if not specified
|
||||
self.batch_size = batch_size
|
||||
self.access_token: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
f"&state={state}"
|
||||
f"&response_type=code"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
|
||||
"scope": "Egnyte.filesystem",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = _request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
# try a lot faster since this is a realtime flow
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.domain = credentials["domain"]
|
||||
self.access_token = credentials["access_token"]
|
||||
return None
|
||||
|
||||
def _get_files_list(
|
||||
self,
|
||||
path: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not self.access_token or not self.domain:
|
||||
raise ConnectorMissingCredentialError("Egnyte")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
|
||||
data = response.json()
|
||||
all_files: list[dict[str, Any]] = []
|
||||
|
||||
# Add files from current directory
|
||||
all_files.extend(data.get("files", []))
|
||||
|
||||
# Recursively traverse folders
|
||||
for item in data.get("folders", []):
|
||||
all_files.extend(self._get_files_list(item["path"]))
|
||||
|
||||
return all_files
|
||||
|
||||
def _filter_files(
|
||||
self,
|
||||
files: list[dict[str, Any]],
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
filtered_files = []
|
||||
for file in files:
|
||||
if file["is_folder"]:
|
||||
continue
|
||||
|
||||
file_modified = _parse_last_modified(file["last_modified"])
|
||||
if start_time and file_modified < start_time:
|
||||
continue
|
||||
if end_time and file_modified > end_time:
|
||||
continue
|
||||
|
||||
filtered_files.append(file)
|
||||
|
||||
return filtered_files
|
||||
|
||||
def _process_files(
|
||||
self,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
files = self._get_files_list(self.folder_path)
|
||||
files = self._filter_files(files, start_time, end_time)
|
||||
|
||||
current_batch: list[Document] = []
|
||||
for file in files:
|
||||
try:
|
||||
# Set up request with streaming enabled
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
|
||||
response = _request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
logger.error(
|
||||
f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
|
||||
)
|
||||
continue
|
||||
|
||||
# Stream the response content into a BytesIO buffer
|
||||
buffer = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
buffer.write(chunk)
|
||||
|
||||
# Reset buffer's position to the start
|
||||
buffer.seek(0)
|
||||
|
||||
# Process the streamed file content
|
||||
doc = _process_egnyte_file(
|
||||
file_metadata=file,
|
||||
file_content=buffer,
|
||||
base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
|
||||
folder_path=self.folder_path,
|
||||
)
|
||||
|
||||
if doc is not None:
|
||||
current_batch.append(doc)
|
||||
|
||||
if len(current_batch) >= self.batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to process file {file['path']}")
|
||||
continue
|
||||
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
yield from self._process_files()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
yield from self._process_files(start_time=start_time, end_time=end_time)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = EgnyteConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"domain": os.environ["EGNYTE_DOMAIN"],
|
||||
"access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -15,6 +15,7 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
|
||||
from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.egnyte.connector import EgnyteConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.fireflies.connector import FirefliesConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
@@ -103,6 +104,7 @@ def identify_connector_class(
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
|
||||
from danswer.file_processing.extract_file_text import detect_encoding
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
from danswer.file_processing.extract_file_text import get_file_ext
|
||||
from danswer.file_processing.extract_file_text import is_text_file_extension
|
||||
from danswer.file_processing.extract_file_text import is_valid_file_ext
|
||||
from danswer.file_processing.extract_file_text import load_files_from_zip
|
||||
from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
|
||||
file_content, ignore_dirs=True
|
||||
):
|
||||
yield os.path.join(directory_path, file_info.filename), file, metadata
|
||||
elif check_file_ext_is_valid(extension):
|
||||
elif is_valid_file_ext(extension):
|
||||
yield file_name, file_content, metadata
|
||||
else:
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
@@ -63,7 +63,7 @@ def _process_file(
|
||||
pdf_pass: str | None = None,
|
||||
) -> list[Document]:
|
||||
extension = get_file_ext(file_name)
|
||||
if not check_file_ext_is_valid(extension):
|
||||
if not is_valid_file_ext(extension):
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
return []
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@ from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_drive.doc_conversion import build_slim_document
|
||||
from danswer.connectors.google_drive.doc_conversion import (
|
||||
@@ -452,12 +454,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if isinstance(self.creds, ServiceAccountCredentials)
|
||||
else self._manage_oauth_retrieval
|
||||
)
|
||||
return retrieval_method(
|
||||
drive_files = retrieval_method(
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
return drive_files
|
||||
|
||||
def _extract_docs_from_google_drive(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -473,6 +477,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
files_to_process = []
|
||||
# Gather the files into batches to be processed in parallel
|
||||
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
|
||||
if (
|
||||
file.get("size")
|
||||
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
|
||||
):
|
||||
logger.warning(
|
||||
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
|
||||
)
|
||||
continue
|
||||
|
||||
files_to_process.append(file)
|
||||
if len(files_to_process) >= LARGE_BATCH_SIZE:
|
||||
yield from _process_files_batch(
|
||||
|
||||
@@ -16,7 +16,7 @@ logger = setup_logger()
|
||||
|
||||
FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
|
||||
"shortcutDetails, owners(emailAddress))"
|
||||
"shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
SLIM_FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
|
||||
|
||||
@@ -2,6 +2,7 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import SlimDocument
|
||||
|
||||
@@ -64,6 +65,23 @@ class SlimConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -132,7 +132,6 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
branchName
|
||||
customerTicketCount
|
||||
description
|
||||
descriptionData
|
||||
comments {
|
||||
nodes {
|
||||
url
|
||||
@@ -215,5 +214,6 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
if __name__ == "__main__":
|
||||
connector = LinearConnector()
|
||||
connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
|
||||
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -12,12 +12,15 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -28,6 +31,8 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -158,21 +163,26 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self.slab_bot_token = slab_bot_token
|
||||
self._slab_bot_token: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -227,3 +237,21 @@ class SlabConnector(LoadConnector, PollConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -171,7 +171,9 @@ def thread_to_doc(
|
||||
else first_message
|
||||
)
|
||||
|
||||
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"
|
||||
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
|
||||
"\n", " "
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"{channel_id}__{thread[0]['ts']}",
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
|
||||
|
||||
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
|
||||
channel_members_list: list[BasicExpertInfo] = []
|
||||
members = channel.members.get().execute_query()
|
||||
members = channel.members.get().execute_query_retry()
|
||||
for member in members:
|
||||
channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
|
||||
return channel_members_list
|
||||
@@ -51,7 +51,7 @@ def _get_threads_from_channel(
|
||||
end = end.replace(tzinfo=timezone.utc)
|
||||
|
||||
query = channel.messages.get()
|
||||
base_messages: list[ChatMessage] = query.execute_query()
|
||||
base_messages: list[ChatMessage] = query.execute_query_retry()
|
||||
|
||||
threads: list[list[ChatMessage]] = []
|
||||
for base_message in base_messages:
|
||||
@@ -65,7 +65,7 @@ def _get_threads_from_channel(
|
||||
continue
|
||||
|
||||
reply_query = base_message.replies.get_all()
|
||||
replies = reply_query.execute_query()
|
||||
replies = reply_query.execute_query_retry()
|
||||
|
||||
# start a list containing the base message and its replies
|
||||
thread: list[ChatMessage] = [base_message]
|
||||
@@ -82,7 +82,7 @@ def _get_channels_from_teams(
|
||||
channels_list: list[Channel] = []
|
||||
for team in teams:
|
||||
query = team.channels.get()
|
||||
channels = query.execute_query()
|
||||
channels = query.execute_query_retry()
|
||||
channels_list.extend(channels)
|
||||
|
||||
return channels_list
|
||||
@@ -210,7 +210,7 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
|
||||
teams_list: list[Team] = []
|
||||
|
||||
teams = self.graph_client.teams.get().execute_query()
|
||||
teams = self.graph_client.teams.get().execute_query_retry()
|
||||
|
||||
if len(self.requested_team_list) > 0:
|
||||
adjusted_request_strings = [
|
||||
@@ -234,14 +234,25 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorMissingCredentialError("Teams")
|
||||
|
||||
teams = self._get_all_teams()
|
||||
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
|
||||
if not teams:
|
||||
msg = "No teams found."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
channels = _get_channels_from_teams(
|
||||
teams=teams,
|
||||
)
|
||||
logger.debug(f"Found available channels: {[c.id for c in channels]}")
|
||||
if not channels:
|
||||
msg = "No channels found."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# goes over channels, converts them into Document objects and then yields them in batches
|
||||
doc_batch: list[Document] = []
|
||||
for channel in channels:
|
||||
logger.debug(f"Fetching threads from channel: {channel.id}")
|
||||
thread_list = _get_threads_from_channel(channel, start=start, end=end)
|
||||
for thread in thread_list:
|
||||
converted_doc = _convert_thread_to_document(channel, thread)
|
||||
@@ -259,8 +270,8 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
start_datetime = datetime.fromtimestamp(start, timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, timezone.utc)
|
||||
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
|
||||
|
||||
|
||||
|
||||
@@ -102,13 +102,21 @@ def _get_tickets(
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
author_data = client.make_request(f"users/{author_id}", {})
|
||||
user = author_data.get("user")
|
||||
return (
|
||||
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
|
||||
if user and user.get("name") and user.get("email")
|
||||
else None
|
||||
)
|
||||
# Skip fetching if author_id is invalid
|
||||
if not author_id or author_id == "-1":
|
||||
return None
|
||||
|
||||
try:
|
||||
author_data = client.make_request(f"users/{author_id}", {})
|
||||
user = author_data.get("user")
|
||||
return (
|
||||
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
|
||||
if user and user.get("name") and user.get("email")
|
||||
else None
|
||||
)
|
||||
except requests.exceptions.HTTPError:
|
||||
# Handle any API errors gracefully
|
||||
return None
|
||||
|
||||
|
||||
def _article_to_document(
|
||||
|
||||
@@ -8,13 +8,13 @@ from pydantic import field_validator
|
||||
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.models import BaseChunk
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
@@ -5,33 +5,33 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.chat.prune_and_merge import _merge_sections
|
||||
from danswer.chat.prune_and_merge import ChunkRange
|
||||
from danswer.chat.prune_and_merge import merge_chunk_intervals
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankMetricsContainer
|
||||
from danswer.context.search.models import RetrievalMetricsContainer
|
||||
from danswer.context.search.models import SearchQuery
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.context.search.postprocessing.postprocessing import search_postprocessing
|
||||
from danswer.context.search.preprocessing.preprocessing import retrieval_preprocessing
|
||||
from danswer.context.search.retrieval.search_runner import retrieve_chunks
|
||||
from danswer.context.search.utils import inference_section_from_chunks
|
||||
from danswer.context.search.utils import relevant_sections_to_indices
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prune_and_merge import _merge_sections
|
||||
from danswer.llm.answering.prune_and_merge import ChunkRange
|
||||
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchRequest
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.postprocessing.postprocessing import search_postprocessing
|
||||
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
|
||||
from danswer.search.retrieval.search_runner import retrieve_chunks
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -9,19 +9,19 @@ from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import ChunkMetric
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceChunkUncleaned
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.context.search.models import RerankMetricsContainer
|
||||
from danswer.context.search.models import SearchQuery
|
||||
from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import RerankingModel
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -1,8 +1,8 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_acl_for_user
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import IndexFilters
|
||||
|
||||
|
||||
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:
|
||||
@@ -9,21 +9,25 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import SearchQuery
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from danswer.context.search.retrieval.search_runner import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchRequest
|
||||
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from danswer.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -6,6 +6,16 @@ from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.context.search.models import ChunkMetric
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceChunkUncleaned
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.context.search.models import RetrievalMetricsContainer
|
||||
from danswer.context.search.models import SearchQuery
|
||||
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.context.search.utils import inference_section_from_chunks
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
@@ -14,16 +24,6 @@ from danswer.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.context.search.models import SavedSearchSettings
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -2,12 +2,12 @@ from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.context.search.models import SavedSearchDocWithContent
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SavedSearchDocWithContent
|
||||
from danswer.search.models import SearchDoc
|
||||
|
||||
|
||||
T = TypeVar(
|
||||
@@ -16,24 +16,31 @@ from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
|
||||
from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
|
||||
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
_MAX_BLURB_LEN = 45
|
||||
|
||||
@@ -101,12 +108,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
return chunks
|
||||
|
||||
|
||||
def clean_markdown_link_text(text: str) -> str:
|
||||
def _clean_markdown_link_text(text: str) -> str:
|
||||
# Remove any newlines within the text
|
||||
return text.replace("\n", " ").strip()
|
||||
|
||||
|
||||
def build_qa_feedback_block(
|
||||
def _build_qa_feedback_block(
|
||||
message_id: int, feedback_reminder_id: str | None = None
|
||||
) -> Block:
|
||||
return ActionsBlock(
|
||||
@@ -115,7 +122,6 @@ def build_qa_feedback_block(
|
||||
ButtonElement(
|
||||
action_id=LIKE_BLOCK_ACTION_ID,
|
||||
text="👍 Helpful",
|
||||
style="primary",
|
||||
value=feedback_reminder_id,
|
||||
),
|
||||
ButtonElement(
|
||||
@@ -155,7 +161,7 @@ def get_document_feedback_blocks() -> Block:
|
||||
)
|
||||
|
||||
|
||||
def build_doc_feedback_block(
|
||||
def _build_doc_feedback_block(
|
||||
message_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
@@ -182,7 +188,7 @@ def get_restate_blocks(
|
||||
]
|
||||
|
||||
|
||||
def build_documents_blocks(
|
||||
def _build_documents_blocks(
|
||||
documents: list[SavedSearchDoc],
|
||||
message_id: int | None,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
@@ -198,7 +204,8 @@ def build_documents_blocks(
|
||||
continue
|
||||
seen_docs_identifiers.add(d.document_id)
|
||||
|
||||
doc_sem_id = d.semantic_identifier
|
||||
# Strip newlines from the semantic identifier for Slackbot formatting
|
||||
doc_sem_id = d.semantic_identifier.replace("\n", " ")
|
||||
if d.source_type == DocumentSource.SLACK.value:
|
||||
doc_sem_id = "#" + doc_sem_id
|
||||
|
||||
@@ -223,7 +230,7 @@ def build_documents_blocks(
|
||||
|
||||
feedback: ButtonElement | dict = {}
|
||||
if message_id is not None:
|
||||
feedback = build_doc_feedback_block(
|
||||
feedback = _build_doc_feedback_block(
|
||||
message_id=message_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
@@ -241,7 +248,7 @@ def build_documents_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def build_sources_blocks(
|
||||
def _build_sources_blocks(
|
||||
cited_documents: list[tuple[int, SavedSearchDoc]],
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
) -> list[Block]:
|
||||
@@ -286,7 +293,7 @@ def build_sources_blocks(
|
||||
+ ([days_ago_str] if days_ago_str else [])
|
||||
)
|
||||
|
||||
document_title = clean_markdown_link_text(doc_sem_id)
|
||||
document_title = _clean_markdown_link_text(doc_sem_id)
|
||||
img_link = source_to_github_img_link(d.source_type)
|
||||
|
||||
section_blocks.append(
|
||||
@@ -317,106 +324,105 @@ def build_sources_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
if not priority_ordered_docs:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
document_blocks = _build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
if document_blocks:
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
return document_blocks
|
||||
|
||||
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
quotes: list[DanswerQuote] | None,
|
||||
source_filters: list[DocumentSource] | None,
|
||||
time_cutoff: datetime | None,
|
||||
favor_recent: bool,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
skip_ai_feedback: bool = False,
|
||||
feedback_reminder_id: str | None = None,
|
||||
def _build_citations_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = _build_sources_blocks(cited_documents=cited_docs)
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if time_cutoff or favor_recent or source_filters:
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
or retrieval_info.applied_source_filters
|
||||
):
|
||||
filter_text = "Filters: "
|
||||
if source_filters:
|
||||
sources_str = ", ".join([s.value for s in source_filters])
|
||||
if retrieval_info.applied_source_filters:
|
||||
sources_str = ", ".join(
|
||||
[s.value for s in retrieval_info.applied_source_filters]
|
||||
)
|
||||
filter_text += f"`Sources in [{sources_str}]`"
|
||||
if time_cutoff or favor_recent:
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
):
|
||||
filter_text += " and "
|
||||
if time_cutoff is not None:
|
||||
time_str = time_cutoff.strftime("%b %d, %Y")
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
|
||||
filter_text += f"`Docs Updated >= {time_str}` "
|
||||
if favor_recent:
|
||||
if time_cutoff is not None:
|
||||
if retrieval_info.recency_bias_multiplier > 1:
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
filter_text += "+ "
|
||||
filter_text += "`Prioritize Recently Updated Docs`"
|
||||
|
||||
filter_block = SectionBlock(text=f"_{filter_text}_")
|
||||
|
||||
if not answer:
|
||||
if not formatted_answer:
|
||||
answer_blocks = [
|
||||
SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
]
|
||||
else:
|
||||
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
|
||||
answer_processed = decode_escapes(
|
||||
remove_slack_text_interactions(formatted_answer)
|
||||
)
|
||||
if process_message_for_citations:
|
||||
answer_processed = _process_citations_for_slack(answer_processed)
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
response_blocks: list[Block] = []
|
||||
|
||||
@@ -425,20 +431,34 @@ def build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if message_id is not None and not skip_ai_feedback:
|
||||
response_blocks.append(
|
||||
build_qa_feedback_block(
|
||||
message_id=message_id, feedback_reminder_id=feedback_reminder_id
|
||||
)
|
||||
)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
)
|
||||
return ActionsBlock(
|
||||
block_id=build_continue_in_web_ui_id(message_id),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
|
||||
text="Continue Chat in Danswer!",
|
||||
style="primary",
|
||||
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_id(message_id) if message_id is not None else None,
|
||||
elements=[
|
||||
@@ -483,3 +503,75 @@ def build_follow_up_resolved_blocks(
|
||||
]
|
||||
)
|
||||
return [text_block, button_block]
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
skip_ai_feedback: bool = False,
|
||||
) -> list[Block]:
|
||||
"""
|
||||
This function is a top level function that builds all the blocks for the Slack response.
|
||||
It also handles combining all the blocks together.
|
||||
"""
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
web_follow_up_block = []
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
follow_up_block = []
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
follow_up_block.append(
|
||||
_build_follow_up_block(message_id=answer.chat_message_id)
|
||||
)
|
||||
|
||||
ai_feedback_block = []
|
||||
if answer.chat_message_id is not None and not skip_ai_feedback:
|
||||
ai_feedback_block.append(
|
||||
_build_qa_feedback_block(
|
||||
message_id=answer.chat_message_id,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations and answer.citations:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block
|
||||
+ answer_blocks
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
|
||||
return all_blocks
|
||||
|
||||
@@ -2,6 +2,7 @@ from enum import Enum
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
|
||||
|
||||
@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_feedback_visibility
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
@@ -267,7 +267,7 @@ def handle_followup_button(
|
||||
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
|
||||
remaining = None
|
||||
if tag_names:
|
||||
tag_ids, remaining = fetch_user_ids_from_emails(
|
||||
tag_ids, remaining = fetch_slack_user_ids_from_emails(
|
||||
tag_names, client.web_client
|
||||
)
|
||||
if remaining:
|
||||
|
||||
@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
handle_standard_answers,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
@@ -184,7 +184,7 @@ def handle_message(
|
||||
send_to: list[str] | None = None
|
||||
missing_users: list[str] | None = None
|
||||
if respond_member_group_list:
|
||||
send_to, missing_ids = fetch_user_ids_from_emails(
|
||||
send_to, missing_ids = fetch_slack_user_ids_from_emails(
|
||||
respond_member_group_list, client
|
||||
)
|
||||
|
||||
|
||||
@@ -1,60 +1,43 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.process_message import gather_stream_for_slack
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.danswerbot.slack.blocks import build_documents_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_follow_up_block
|
||||
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
|
||||
from danswer.danswerbot.slack.blocks import build_sources_blocks
|
||||
from danswer.danswerbot.slack.blocks import get_restate_blocks
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
@@ -89,16 +72,14 @@ def handle_regular_answer(
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
@@ -108,9 +89,18 @@ def handle_regular_answer(
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
prompt = None
|
||||
if persona:
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
else:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
@@ -118,6 +108,26 @@ def handle_regular_answer(
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
# llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm.config.model_name,
|
||||
# provider_type=llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# # In cases of threads, split the available tokens between docs and thread context
|
||||
# input_tokens = get_max_input_tokens(
|
||||
# model_name=llm.config.model_name,
|
||||
# model_provider=llm.config.model_provider,
|
||||
# )
|
||||
# max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
# combined_message = combine_message_thread(
|
||||
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
|
||||
# )
|
||||
|
||||
combined_message = slackify_message_thread(messages)
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_channel_config
|
||||
@@ -128,13 +138,6 @@ def handle_regular_answer(
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_channel_config is None
|
||||
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
@@ -147,75 +150,23 @@ def handle_regular_answer(
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, danswer_user: User | None
|
||||
) -> ChatDanswerBotResponse:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=danswer_user,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
answer = gather_stream_for_slack(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@@ -245,26 +196,24 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
persona_id=persona.id,
|
||||
# This is not used in the Slack flow, only in the answer API
|
||||
persona_override_config=None,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
answer = _get_slack_answer(
|
||||
new_message_request=answer_request, danswer_user=user
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -365,7 +314,7 @@ def handle_regular_answer(
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
f"Unable to answer question: '{combined_message}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -386,18 +335,18 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_with_citations_or_quotes = (
|
||||
only_respond_if_citations = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
|
||||
if (
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
only_respond_if_citations
|
||||
and not answer.citations
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -411,67 +360,22 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=formatted_answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
receiver_ids=[message_info.sender]
|
||||
if message_info.is_bot_msg and message_info.sender
|
||||
else receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
|
||||
@@ -1,8 +1,33 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
# Note: this does not handle extremely long threads, every message will be included
|
||||
# with weaker LLMs, this could cause issues with exceeeding the token limit
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"AI said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
|
||||
@@ -19,6 +19,8 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.app_configs import DEV_MODE
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -27,6 +29,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.context.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
|
||||
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
|
||||
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
|
||||
@@ -73,9 +76,7 @@ from danswer.db.slack_bot import fetch_slack_bots
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
@@ -250,7 +251,7 @@ class SlackbotHandler:
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired:
|
||||
if not acquired and not DEV_MODE:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.chat.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
|
||||
@@ -3,14 +3,15 @@ import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.models.blocks import Block
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.metadata import Metadata
|
||||
from slack_sdk.socket_mode import SocketModeClient
|
||||
|
||||
@@ -30,13 +31,13 @@ from danswer.configs.danswerbot_configs import (
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.models import ThreadMessage
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
@@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
|
||||
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
|
||||
|
||||
|
||||
def _check_for_url_in_block(block: Block) -> bool:
|
||||
"""
|
||||
Check if the block has a key that contains "url" in it
|
||||
"""
|
||||
block_dict = block.to_dict()
|
||||
|
||||
def check_dict_for_url(d: dict) -> bool:
|
||||
for key, value in d.items():
|
||||
if "url" in key.lower():
|
||||
return True
|
||||
if isinstance(value, dict):
|
||||
if check_dict_for_url(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_dict_for_url(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
return check_dict_for_url(block_dict)
|
||||
|
||||
|
||||
def _build_error_block(error_message: str) -> Block:
|
||||
"""
|
||||
Build an error block to display in slack so that the user can see
|
||||
the error without completely breaking
|
||||
"""
|
||||
display_text = (
|
||||
"There was an error displaying all of the Onyx answers."
|
||||
f" Please let an admin or an onyx developer know. Error: {error_message}"
|
||||
)
|
||||
return SectionBlock(text=display_text)
|
||||
|
||||
|
||||
@retry(
|
||||
tries=DANSWER_BOT_NUM_RETRIES,
|
||||
delay=0.25,
|
||||
@@ -162,24 +197,9 @@ def respond_in_thread(
|
||||
message_ids: list[str] = []
|
||||
if not receiver_ids:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
text=text,
|
||||
blocks=blocks,
|
||||
thread_ts=thread_ts,
|
||||
metadata=metadata,
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
for receiver in receiver_ids:
|
||||
try:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
user=receiver,
|
||||
text=text,
|
||||
blocks=blocks,
|
||||
thread_ts=thread_ts,
|
||||
@@ -187,8 +207,68 @@ def respond_in_thread(
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
|
||||
logger.warning("Trying again without blocks that have urls")
|
||||
|
||||
if not blocks:
|
||||
raise e
|
||||
|
||||
blocks_without_urls = [
|
||||
block for block in blocks if not _check_for_url_in_block(block)
|
||||
]
|
||||
blocks_without_urls.append(_build_error_block(str(e)))
|
||||
|
||||
# Try again wtihout blocks containing url
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
text=text,
|
||||
blocks=blocks_without_urls,
|
||||
thread_ts=thread_ts,
|
||||
metadata=metadata,
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
for receiver in receiver_ids:
|
||||
try:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
user=receiver,
|
||||
text=text,
|
||||
blocks=blocks,
|
||||
thread_ts=thread_ts,
|
||||
metadata=metadata,
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
|
||||
logger.warning("Trying again without blocks that have urls")
|
||||
|
||||
if not blocks:
|
||||
raise e
|
||||
|
||||
blocks_without_urls = [
|
||||
block for block in blocks if not _check_for_url_in_block(block)
|
||||
]
|
||||
blocks_without_urls.append(_build_error_block(str(e)))
|
||||
|
||||
# Try again wtihout blocks containing url
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
user=receiver,
|
||||
text=text,
|
||||
blocks=blocks_without_urls,
|
||||
thread_ts=thread_ts,
|
||||
metadata=metadata,
|
||||
unfurl_links=unfurl,
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
|
||||
message_ids.append(response["message_ts"])
|
||||
|
||||
return message_ids
|
||||
@@ -216,6 +296,13 @@ def build_feedback_id(
|
||||
return unique_prefix + ID_SEPARATOR + feedback_id
|
||||
|
||||
|
||||
def build_continue_in_web_ui_id(
|
||||
message_id: int,
|
||||
) -> str:
|
||||
unique_prefix = str(uuid.uuid4())[:10]
|
||||
return unique_prefix + ID_SEPARATOR + str(message_id)
|
||||
|
||||
|
||||
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
|
||||
"""Decompose into query_id, document_id, document_rank, see above function"""
|
||||
try:
|
||||
@@ -313,7 +400,7 @@ def get_channel_name_from_id(
|
||||
raise e
|
||||
|
||||
|
||||
def fetch_user_ids_from_emails(
|
||||
def fetch_slack_user_ids_from_emails(
|
||||
user_emails: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
user_ids: list[str] = []
|
||||
@@ -522,7 +609,7 @@ class SlackRateLimiter:
|
||||
self.last_reset_time = time.time()
|
||||
|
||||
def notify(
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: str | None
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
|
||||
@@ -2,6 +2,7 @@ import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
|
||||
]
|
||||
|
||||
|
||||
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
|
||||
api_key = db_session.scalar(
|
||||
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
async def fetch_user_for_api_key(
|
||||
hashed_api_key: str, async_db_session: AsyncSession
|
||||
) -> User | None:
|
||||
"""NOTE: this is async, since it's used during auth
|
||||
(which is necessarily async due to FastAPI Users)"""
|
||||
return await async_db_session.scalar(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
)
|
||||
if api_key is None:
|
||||
return None
|
||||
|
||||
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
|
||||
|
||||
|
||||
def get_api_key_fake_email(
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
@@ -18,6 +19,9 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.context.search.models import SearchDoc as ServerSearchDoc
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatMessage__SearchDoc
|
||||
from danswer.db.models import ChatSession
|
||||
@@ -27,13 +31,11 @@ from danswer.db.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_best_persona_id_for_user
|
||||
from danswer.db.pg_file_store import delete_lobj_by_name
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SearchDoc as ServerSearchDoc
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -143,16 +145,10 @@ def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -224,12 +220,11 @@ def delete_messages_and_files_from_chat_session(
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
description: str | None,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
@@ -239,7 +234,6 @@ def create_chat_session(
|
||||
description=description,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -250,6 +244,48 @@ def create_chat_session(
|
||||
return chat_session
|
||||
|
||||
|
||||
def duplicate_chat_session_for_user_from_slack(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
chat_session_id: UUID,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
This takes a chat session id for a session in Slack and:
|
||||
- Creates a new chat session in the DB
|
||||
- Tries to copy the persona from the original chat session
|
||||
(if it is available to the user clicking the button)
|
||||
- Sets the user to the given user (if provided)
|
||||
"""
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=None, # Ignore user permissions for this
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_session:
|
||||
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
|
||||
|
||||
# This enforces permissions and sets a default
|
||||
new_persona_id = get_best_persona_id_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona_id=chat_session.persona_id,
|
||||
)
|
||||
|
||||
return create_chat_session(
|
||||
db_session=db_session,
|
||||
user_id=user.id if user else None,
|
||||
persona_id=new_persona_id,
|
||||
# Set this to empty string so the frontend will force a rename
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
slack_thread_id=None,
|
||||
)
|
||||
|
||||
|
||||
def update_chat_session(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
@@ -336,6 +372,28 @@ def get_chat_message(
|
||||
return chat_message
|
||||
|
||||
|
||||
def get_chat_session_by_message_id(
|
||||
db_session: Session,
|
||||
message_id: int,
|
||||
) -> ChatSession:
|
||||
"""
|
||||
Should only be used for Slack
|
||||
Get the chat session associated with a specific message ID
|
||||
Note: this ignores permission checks.
|
||||
"""
|
||||
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_message = result.scalar_one_or_none()
|
||||
|
||||
if chat_message is None:
|
||||
raise ValueError(
|
||||
f"Unable to find chat session associated with message ID: {message_id}"
|
||||
)
|
||||
|
||||
return chat_message.chat_session
|
||||
|
||||
|
||||
def get_chat_messages_by_sessions(
|
||||
chat_session_ids: list[UUID],
|
||||
user_id: UUID | None,
|
||||
@@ -355,6 +413,44 @@ def get_chat_messages_by_sessions(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def add_chats_to_session_from_slack_thread(
|
||||
db_session: Session,
|
||||
slack_chat_session_id: UUID,
|
||||
new_chat_session_id: UUID,
|
||||
) -> None:
|
||||
new_root_message = get_or_create_root_message(
|
||||
chat_session_id=new_chat_session_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
for chat_message in get_chat_messages_by_sessions(
|
||||
chat_session_ids=[slack_chat_session_id],
|
||||
user_id=None, # Ignore user permissions for this
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
):
|
||||
if chat_message.message_type == MessageType.SYSTEM:
|
||||
continue
|
||||
# Duplicate the message
|
||||
new_root_message = create_new_chat_message(
|
||||
db_session=db_session,
|
||||
chat_session_id=new_chat_session_id,
|
||||
parent_message=new_root_message,
|
||||
message=chat_message.message,
|
||||
files=chat_message.files,
|
||||
rephrased_query=chat_message.rephrased_query,
|
||||
error=chat_message.error,
|
||||
citations=chat_message.citations,
|
||||
reference_docs=chat_message.search_docs,
|
||||
tool_call=chat_message.tool_call,
|
||||
prompt_id=chat_message.prompt_id,
|
||||
token_count=chat_message.token_count,
|
||||
message_type=chat_message.message_type,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
|
||||
def get_search_docs_for_chat_message(
|
||||
chat_message_id: int, db_session: Session
|
||||
) -> list[SearchDoc]:
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
|
||||
# If this changes, we need to update this function.
|
||||
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_ccpair_with_indexing_trigger(
|
||||
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
|
||||
) -> None:
|
||||
"""indexing_mode sets a field which will be picked up by a background task
|
||||
to trigger indexing. Set to None to disable the trigger."""
|
||||
try:
|
||||
cc_pair = db_session.execute(
|
||||
select(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
|
||||
|
||||
cc_pair.indexing_trigger = indexing_mode
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
def _relate_groups_to_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
user_group_ids: list[int],
|
||||
user_group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
if not user_group_ids:
|
||||
return
|
||||
|
||||
for group_id in user_group_ids:
|
||||
db_session.add(
|
||||
UserGroup__ConnectorCredentialPair(
|
||||
@@ -402,12 +405,11 @@ def add_credential_to_connector(
|
||||
db_session.flush() # make sure the association has an id
|
||||
db_session.refresh(association)
|
||||
|
||||
if groups and access_type != AccessType.SYNC:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -248,7 +247,6 @@ def create_credential(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return credential
|
||||
|
||||
|
||||
@@ -263,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit(
|
||||
|
||||
def alter_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialDataUpdateRequest,
|
||||
name: str,
|
||||
credential_json: dict[str, Any],
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> Credential | None:
|
||||
@@ -273,11 +272,13 @@ def alter_credential(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.name = credential_data.name
|
||||
credential.name = name
|
||||
|
||||
# Update only the keys present in credential_data.credential_json
|
||||
for key, value in credential_data.credential_json.items():
|
||||
credential.credential_json[key] = value
|
||||
# Assign a new dictionary to credential.credential_json
|
||||
credential.credential_json = {
|
||||
**credential.credential_json,
|
||||
**credential_json,
|
||||
}
|
||||
|
||||
credential.user_id = user.id if user is not None else None
|
||||
db_session.commit()
|
||||
@@ -310,8 +311,8 @@ def update_credential_json(
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
return None
|
||||
credential.credential_json = credential_json
|
||||
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
return credential
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -426,7 +427,9 @@ def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ class IndexingStatus(str, PyEnum):
|
||||
NOT_STARTED = "not_started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
COMPLETED_WITH_ERRORS = "completed_with_errors"
|
||||
|
||||
@@ -12,11 +13,17 @@ class IndexingStatus(str, PyEnum):
|
||||
terminal_states = {
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.COMPLETED_WITH_ERRORS,
|
||||
IndexingStatus.CANCELED,
|
||||
IndexingStatus.FAILED,
|
||||
}
|
||||
return self in terminal_states
|
||||
|
||||
|
||||
class IndexingMode(str, PyEnum):
|
||||
UPDATE = "update"
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
# these may differ in the future, which is why we're okay with this duplication
|
||||
class DeletionStatus(str, PyEnum):
|
||||
NOT_STARTED = "not_started"
|
||||
|
||||
@@ -225,6 +225,28 @@ def mark_attempt_partially_succeeded(
|
||||
raise
|
||||
|
||||
|
||||
def mark_attempt_canceled(
|
||||
index_attempt_id: int,
|
||||
db_session: Session,
|
||||
reason: str = "Unknown",
|
||||
) -> None:
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(IndexAttempt.id == index_attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if not attempt.time_started:
|
||||
attempt.time_started = datetime.now(timezone.utc)
|
||||
attempt.status = IndexingStatus.CANCELED
|
||||
attempt.error_msg = reason
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def mark_attempt_failed(
|
||||
index_attempt_id: int,
|
||||
db_session: Session,
|
||||
@@ -500,12 +522,16 @@ def expire_index_attempts(
|
||||
search_settings_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
delete_query = (
|
||||
delete(IndexAttempt)
|
||||
not_started_query = (
|
||||
update(IndexAttempt)
|
||||
.where(IndexAttempt.search_settings_id == search_settings_id)
|
||||
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
.values(
|
||||
status=IndexingStatus.CANCELED,
|
||||
error_msg="Canceled, likely due to model swap",
|
||||
)
|
||||
)
|
||||
db_session.execute(delete_query)
|
||||
db_session.execute(not_started_query)
|
||||
|
||||
update_query = (
|
||||
update(IndexAttempt)
|
||||
@@ -527,9 +553,14 @@ def cancel_indexing_attempts_for_ccpair(
|
||||
include_secondary_index: bool = False,
|
||||
) -> None:
|
||||
stmt = (
|
||||
delete(IndexAttempt)
|
||||
update(IndexAttempt)
|
||||
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
|
||||
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
.values(
|
||||
status=IndexingStatus.CANCELED,
|
||||
error_msg="Canceled by user",
|
||||
time_started=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
if not include_secondary_index:
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import InputPrompt
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.input_prompt.models import InputPromptSnapshot
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_input_prompt_if_not_exists(
|
||||
user: User | None,
|
||||
input_prompt_id: int | None,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> InputPrompt:
|
||||
if input_prompt_id is not None:
|
||||
input_prompt = (
|
||||
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
|
||||
)
|
||||
else:
|
||||
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
|
||||
if user:
|
||||
query = query.filter(InputPrompt.user_id == user.id)
|
||||
else:
|
||||
query = query.filter(InputPrompt.user_id.is_(None))
|
||||
input_prompt = query.first()
|
||||
|
||||
if input_prompt is None:
|
||||
input_prompt = InputPrompt(
|
||||
id=input_prompt_id,
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=active,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def insert_input_prompt(
|
||||
prompt: str,
|
||||
content: str,
|
||||
is_public: bool,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = InputPrompt(
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=True,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user is not None else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def update_input_prompt(
|
||||
user: User | None,
|
||||
input_prompt_id: int,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You don't own this prompt")
|
||||
|
||||
input_prompt.prompt = prompt
|
||||
input_prompt.content = content
|
||||
input_prompt.active = active
|
||||
|
||||
db_session.commit()
|
||||
return input_prompt
|
||||
|
||||
|
||||
def validate_user_prompt_authorization(
|
||||
user: User | None, input_prompt: InputPrompt
|
||||
) -> bool:
|
||||
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
|
||||
|
||||
if prompt.user_id is not None:
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
user_details = UserInfo.from_model(user)
|
||||
if str(user_details.id) != str(prompt.user_id):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not input_prompt.is_public:
|
||||
raise HTTPException(status_code=400, detail="This prompt is not public")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_input_prompt(
|
||||
user: User | None, input_prompt_id: int, db_session: Session
|
||||
) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if input_prompt.is_public:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete public prompts with this method"
|
||||
)
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You do not own this prompt")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_input_prompt_by_id(
|
||||
id: int, user_id: UUID | None, db_session: Session
|
||||
) -> InputPrompt:
|
||||
query = select(InputPrompt).where(InputPrompt.id == id)
|
||||
|
||||
if user_id:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
|
||||
)
|
||||
else:
|
||||
# If no user_id is provided, only fetch prompts without a user_id (aka public)
|
||||
query = query.where(InputPrompt.user_id == None) # noqa
|
||||
|
||||
result = db_session.scalar(query)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(422, "No input prompt found")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_public_input_prompts(
|
||||
db_session: Session,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt).where(InputPrompt.is_public)
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def fetch_input_prompts_by_user(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
active: bool | None = None,
|
||||
include_public: bool = False,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt)
|
||||
|
||||
if user_id is not None:
|
||||
if include_public:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | InputPrompt.is_public
|
||||
)
|
||||
else:
|
||||
query = query.where(InputPrompt.user_id == user_id)
|
||||
|
||||
elif include_public:
|
||||
query = query.where(InputPrompt.is_public)
|
||||
|
||||
if active is not None:
|
||||
query = query.where(InputPrompt.active == active)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user