Compare commits

...

85 Commits

Author SHA1 Message Date
Yuhong Sun
4293543a6a k 2024-07-20 16:48:05 -07:00
Yuhong Sun
e95bfa0e0b Suffix Test (#1880) 2024-07-20 15:54:55 -07:00
Yuhong Sun
4848b5f1de Suffix Edits (#1878) 2024-07-20 13:59:14 -07:00
Yuhong Sun
7ba5c434fa Missing Comma (#1877) 2024-07-19 22:15:45 -07:00
Yuhong Sun
59bf5ba848 File Connector Metadata (#1876) 2024-07-19 20:45:18 -07:00
Weves
f66c33380c Improve widget README 2024-07-19 20:21:07 -07:00
Weves
115650ce9f Add example widget code 2024-07-19 20:14:52 -07:00
Weves
7aa3602fca Fix black 2024-07-19 18:55:09 -07:00
Weves
864c552a17 Fix UT 2024-07-19 18:55:09 -07:00
Brent Kwok
07b2ed3d8f Fix HTTP 422 error for api_inference_sample.py (#1868) 2024-07-19 18:54:43 -07:00
Yuhong Sun
38290057f2 Search Eval (#1873) 2024-07-19 16:48:58 -07:00
Weves
2344edf158 Change default login time to 7 days 2024-07-19 13:58:50 -07:00
versecafe
86d1804eb0 Add GPT-4o-Mini & fix a missing gpt-4o 2024-07-19 12:10:27 -07:00
pablodanswer
1ebae50d0c minor update 2024-07-19 10:53:28 -07:00
Weves
a9fbaa396c Stop building on every PR 2024-07-19 10:21:19 -07:00
pablodanswer
27d5f69427 update to headers (#1864) 2024-07-19 08:38:54 -07:00
pablodanswer
5d98421ae8 show "analysis" (#1863) 2024-07-18 18:18:36 -07:00
Kevin Shi
6b561b8ca9 Add config to skip zendesk article labels 2024-07-18 18:00:51 -07:00
pablodanswer
2dc7e64dd7 fix internet search icons / text + assistants tab (#1862) 2024-07-18 16:15:19 -07:00
Yuhong Sun
5230f7e22f Enforce Disable GenAI if set (#1860) 2024-07-18 13:25:55 -07:00
hagen-danswer
a595d43ae3 Fixed deleting toolcall by message 2024-07-18 12:52:28 -07:00
Yuhong Sun
ee561f42ff Cleaner Layout (#1857) 2024-07-18 11:13:16 -07:00
Yuhong Sun
f00b3d76b3 Touchup NoOp (#1856) 2024-07-18 08:44:27 -07:00
Yuhong Sun
e4984153c0 Touchups (#1855) 2024-07-17 23:47:10 -07:00
pablodanswer
87fadb07ea COMPLETE USER EXPERIENCE OVERHAUL (#1822) 2024-07-17 19:44:21 -07:00
pablodanswer
2b07c102f9 fix discourse connector rate limiting + topic fetching (#1820) 2024-07-17 14:57:40 -07:00
hagen-danswer
e93de602c3 Use SHA instead of branch and save more data (#1850) 2024-07-17 14:56:24 -07:00
hagen-danswer
1c77395503 Fixed llm_indices from document search api (#1853) 2024-07-17 14:52:49 -07:00
Victorivus
cdf6089b3e Fix bug XML files in chat (#1804) 2024-07-17 08:09:40 -07:00
pablodanswer
d01f46af2b fix search doc bug (#1851) 2024-07-16 15:27:04 -07:00
hagen-danswer
b83f435bb0 Catch dropped eval questions and added multiprocessing (#1849) 2024-07-16 12:33:02 -07:00
hagen-danswer
25b3dacaba Separated model caching volumes (#1845) 2024-07-15 15:32:04 -07:00
hagen-danswer
a1e638a73d Improved eval logging and stability (#1843) 2024-07-15 14:58:45 -07:00
Yuhong Sun
bd1e0c5969 Add Enum File (#1842) 2024-07-15 09:13:27 -07:00
Yuhong Sun
4d295ab97d Model Server Logging (#1839) 2024-07-15 09:00:27 -07:00
Weves
6fe3eeaa48 Fix model server startup 2024-07-14 23:33:58 -07:00
Chris Weaver
078d5defbb Update Slack link in README.md 2024-07-14 16:50:48 -07:00
Weves
0d52e99bd4 Improve confluence rate limiting 2024-07-14 16:40:45 -07:00
hagen-danswer
1b864a00e4 Added support for multiple Eval Pipeline UIs (#1830) 2024-07-14 15:16:20 -07:00
Weves
dae4f6a0bd Fix latency caused by large numbers of tags 2024-07-14 14:21:07 -07:00
Yuhong Sun
f63d0ca3ad Title Truncation Logic (#1828) 2024-07-14 13:54:36 -07:00
Yuhong Sun
da31da33e7 Fix Title for docs without (#1827) 2024-07-14 13:51:11 -07:00
Yuhong Sun
56b175f597 Fix Sitemap Robo (#1826) 2024-07-14 13:29:26 -07:00
Zoltan Szabo
1b311d092e Try to find the sitemap for a given site (#1538) 2024-07-14 13:24:10 -07:00
Moshe Zada
6ee1292757 Fix semantic id for web pdfs (#1823) 2024-07-14 11:38:11 -07:00
Yuhong Sun
017af052be Global Tokenizer Fix (#1825) 2024-07-14 11:37:10 -07:00
pablodanswer
e7f81d1688 add third party embedding models (#1818) 2024-07-14 10:19:53 -07:00
Weves
b6bd818e60 Fix user groups page when a persona is deleted 2024-07-13 15:35:50 -07:00
hagen-danswer
36da2e4b27 Fixed slack groups (#1814)
* Simplified slackbot response groups and fixed need more help bug

* mypy fixes

* added exceptions for the couldn't find passthrough arrays
2024-07-13 22:34:35 +00:00
pablodanswer
c7af6a4601 add new standard answer test endpoint (#1789) 2024-07-12 10:06:30 -07:00
Yuhong Sun
e90c66c1b6 Include Titles in Chunks (#1817) 2024-07-12 09:42:24 -07:00
hagen-danswer
8c312482c1 fixed id retrieval from zip metadata (#1813) 2024-07-11 20:38:12 -07:00
Weves
e50820e65e Remove Internet 'Connector' that mistakenly appears on the Add Connector page 2024-07-11 18:00:59 -07:00
hagen-danswer
991ee79e47 some qol improvements for search pipeline (#1809) 2024-07-11 17:42:11 -07:00
hagen-danswer
3e645a510e Fix slack error logging (#1800) 2024-07-11 08:31:48 -07:00
Yuhong Sun
08c6e821e7 Merge Sections Logic (#1801) 2024-07-10 20:14:02 -07:00
hagen-danswer
47a550221f slackbot doesn't respond without citations/quotes (#1798)
* slackbot doesn't respond without citations/quotes

fixed logical issues

fixed dict logic

* added slackbot shim for the llm source/time feature

* mypy fixes

* slackbot doesn't respond without citations/quotes

fixed logical issues

fixed dict logic

* Update handle_regular_answer.py

* added bypass_filter check

* final fixes
2024-07-11 00:18:26 +00:00
Weves
511f619212 Add content to /document-search response 2024-07-10 15:44:58 -07:00
Varun Gaur
6c51f001dc Confluence Connector to Sync Child pages only (#1629)
---------

Co-authored-by: Varun Gaur <vgaur@roku.com>
Co-authored-by: hagen-danswer <hagen@danswer.ai>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-07-10 14:17:03 -07:00
pablodanswer
09a11b5e1a Fix citations + unit tests (#1760) 2024-07-10 10:05:20 -07:00
pablodanswer
aa0f7abdac add basic table wrapping (#1791) 2024-07-09 19:14:41 -07:00
Yuhong Sun
7c8f8dba17 Break the Danswer LLM logging from LiteLLM Verbose (#1795) 2024-07-09 18:18:29 -07:00
Yuhong Sun
39982e5fdc Info propagating to allow Chunk Merging (#1794) 2024-07-09 18:15:07 -07:00
pablodanswer
5e0de111f9 fix wrapping in error hover connector (#1790) 2024-07-09 11:54:35 -07:00
pablodanswer
727d80f168 fix gpt-4o image issue (#1786) 2024-07-08 23:07:53 +00:00
rashad-danswer
146f85936b Internet Search Tool (#1666)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-06 18:01:24 -07:00
Chris Weaver
e06f8a0a4b Standard Answers (#1753)
---------

Co-authored-by: druhinsgoel <druhin@danswer.ai>
2024-07-06 16:11:11 -07:00
Yuhong Sun
f0888f2f61 Eval Script Incremental Write (#1784) 2024-07-06 15:43:40 -07:00
Yuhong Sun
d35d7ee833 Evaluation Pipeline Touchup (#1783) 2024-07-06 13:17:05 -07:00
Yuhong Sun
c5bb3fde94 Ignore Eval Files (#1782) 2024-07-06 12:15:03 -07:00
Yuhong Sun
79190030a5 New Env File for Eval (#1781) 2024-07-06 12:07:31 -07:00
Yuhong Sun
8e8f262ed3 Docker Compose Eval Pipeline Cleanup (#1780) 2024-07-06 12:04:57 -07:00
hagen-danswer
ac14369716 Added search quality testing pipeline (#1774) 2024-07-06 11:51:50 -07:00
Weves
de4d8e9a65 Fix shared chats 2024-07-04 11:41:16 -07:00
hagen-danswer
0b384c5b34 fixed salesforce url generation (#1777) 2024-07-04 10:43:21 -07:00
Weves
fa049f4f98 Add UI support for github configs 2024-07-03 17:37:59 -07:00
pablodanswer
72d6a0ef71 minor updates to assistant UI (#1771) 2024-07-03 18:28:25 +00:00
pablodanswer
ae4e643266 Update Assistants Creation UI (#1714)
* slide up "Tools"

* rework assistants page

* update layout

* reorg complete

- pending: useful header text?

* add tooltips

* alter organizational structure

* rm shadcn

* rm dependencies

* revalidate dependencies

* restore

* update component structure

* [s] format

* rm package json

* add package-lock.json [s]

* collapsible

* naming + width

* formatting

* formatting

* updated user flow

- Fix error/detail messages
- Fix tooltip delay
- Fix icons

* 1 -> 2

* naming fixes

* ran pretty

* fix build issue?

* web build issues?
2024-07-03 17:11:14 +00:00
hagen-danswer
a7da07afc0 allowed arbitrary types to handle the sqlalchemy datatype (#1758)
* allowed arbitrary types to handle the sqlalchemy datatype

* changed persona_upsert to take in ids instead of objects
2024-07-03 07:10:57 +00:00
Weves
7f1bb67e52 Pass through API base to ImageGenerationTool 2024-07-02 23:31:04 -07:00
Weves
982b1b0c49 Add litellm.set_verbose support 2024-07-02 23:22:17 -07:00
Daniel Naber
2db128fb36 Notion date filter fix (#1755)
* fix filter logic

* make comparison better readable
2024-07-02 15:39:35 -07:00
Christoph Petzold
3ebac6256f Fix "cannot access local variable" for bot direct messages (#1737)
* Update handle_message.py

* Update handle_message.py

* Update handle_message.py

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-07-02 15:36:23 -07:00
Weves
1a3ec59610 Fix build caused by bad seeding config 2024-07-01 23:41:43 -07:00
hagen-danswer
581cb827bb added settings and persona seeding options (#1742)
* added settings and persona seeding options

* updated recency_bias

* changed variable type

* another fix

* Update seeding.py

* fixed mypy

* push
2024-07-01 22:22:17 +00:00
338 changed files with 24414 additions and 5910 deletions


@@ -1,8 +1,6 @@
name: Build Backend Image on Merge Group
on:
pull_request:
branches: [ "main" ]
merge_group:
types: [checks_requested]


@@ -1,8 +1,6 @@
name: Build Web Image on Merge Group
on:
pull_request:
branches: [ "main" ]
merge_group:
types: [checks_requested]

.gitignore vendored

@@ -6,3 +6,4 @@
/deployment/data/nginx/app.conf
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml


@@ -8,7 +8,7 @@ AUTH_TYPE=disabled
# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_ALL_MODEL_INTERACTIONS=True
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
@@ -25,11 +25,6 @@ OAUTH_CLIENT_SECRET=<REPLACE THIS>
REQUIRE_EMAIL_VERIFICATION=False
# Toggles on/off the EE Features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper
@@ -47,6 +42,11 @@ PYTHONPATH=./backend
PYTHONUNBUFFERED=1
# Internet Search
BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False


@@ -49,7 +49,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_ALL_MODEL_INTERACTIONS": "True",
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
@@ -83,6 +83,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -105,6 +106,24 @@
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Pytest",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a specific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
]
}
}


@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">

backend/.gitignore vendored

@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*


@@ -0,0 +1,32 @@
"""add search doc relevance details
Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"search_doc",
sa.Column("is_relevant", sa.Boolean(), nullable=True),
)
op.add_column(
"search_doc",
sa.Column("relevance_explanation", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("search_doc", "relevance_explanation")
op.drop_column("search_doc", "is_relevant")


@@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:


@@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")


@@ -0,0 +1,23 @@
"""added is_internet to DBDoc
Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"
def upgrade() -> None:
op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("tool", "display_name")
op.drop_column("search_doc", "is_internet")


@@ -0,0 +1,35 @@
"""added slack_auto_filter
Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015
"""
from alembic import op
import sqlalchemy as sa
revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"slack_bot_config",
sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
)
op.alter_column(
"slack_bot_config",
"enable_auto_filters",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
def downgrade() -> None:
op.drop_column("slack_bot_config", "enable_auto_filters")


@@ -0,0 +1,23 @@
"""backfill is_internet data to False
Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")
def downgrade() -> None:
pass


@@ -0,0 +1,75 @@
"""Add standard_answer tables
Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"standard_answer",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("keyword", sa.String(), nullable=False),
sa.Column("answer", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("keyword"),
)
op.create_table(
"standard_answer_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"standard_answer__standard_answer_category",
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
)
op.create_table(
"slack_bot_config__standard_answer_category",
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["slack_bot_config_id"],
["slack_bot_config.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_category_id"],
["standard_answer_category.id"],
),
sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
)
op.add_column(
"chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("chat_session", "slack_thread_id")
op.drop_table("slack_bot_config__standard_answer_category")
op.drop_table("standard_answer__standard_answer_category")
op.drop_table("standard_answer_category")
op.drop_table("standard_answer")


@@ -0,0 +1,45 @@
"""combined slack id fields
Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
channel_config,
'{respond_member_group_list}',
coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
) - 'respond_team_member_list' - 'respond_slack_group_list'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE slack_bot_config
SET channel_config = jsonb_set(
jsonb_set(
channel_config - 'respond_member_group_list',
'{respond_team_member_list}',
'[]'::jsonb
),
'{respond_slack_group_list}',
'[]'::jsonb
)
"""
)
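
The upgrade above concatenates the two legacy JSONB lists into the new respond_member_group_list key and drops the originals. A minimal Python sketch of the same transformation, using hypothetical sample data:

before = {
    "channel_name": "#support",
    "respond_team_member_list": ["alice@example.com"],
    "respond_slack_group_list": ["oncall"],
}

def merge_slack_id_fields(channel_config: dict) -> dict:
    merged = dict(channel_config)
    # coalesce(..., '[]'::jsonb) equivalent: treat missing keys as empty lists
    members = merged.pop("respond_team_member_list", None) or []
    groups = merged.pop("respond_slack_group_list", None) or []
    merged["respond_member_group_list"] = members + groups
    return merged

print(merge_slack_id_fields(before))
# {'channel_name': '#support', 'respond_member_group_list': ['alice@example.com', 'oncall']}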


@@ -98,7 +98,6 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@@ -116,6 +115,8 @@ def _run_indexing(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")


@@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient


@@ -1,5 +1,4 @@
import re
from collections.abc import Sequence
from typing import cast
from sqlalchemy.orm import Session
@@ -9,42 +8,30 @@ from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
document_id=inference_section.center_chunk.document_id,
# This one is using the combined content of all the chunks of the section
# In default settings, this is the same as just the content of base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
content=inference_section.combined_content,
blurb=inference_section.center_chunk.blurb,
semantic_identifier=inference_section.center_chunk.semantic_identifier,
source_type=inference_section.center_chunk.source_type,
metadata=inference_section.center_chunk.metadata,
updated_at=inference_section.center_chunk.updated_at,
link=inference_section.center_chunk.source_links[0]
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
)
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
def create_chat_chain(
chat_session_id: int,
db_session: Session,


@@ -1,5 +1,3 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
@@ -50,7 +48,7 @@ def load_personas_from_yaml(
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
@@ -58,22 +56,24 @@ def load_personas_from_yaml(
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
prompt_set_names = persona["prompts"]
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
prompts = [
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if not prompts:
prompts = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
upsert_persona(
@@ -91,8 +91,8 @@ def load_personas_from_yaml(
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
default_persona=True,
is_public=True,
db_session=db_session,


@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
# TODO make this document level. Also slight misnomer here as this is actually
# done at the section level currently rather than the chunk
relevant: bool | None = None
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer


@@ -10,14 +10,15 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -49,9 +50,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
@@ -66,6 +71,14 @@ from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -143,6 +156,37 @@ def _handle_search_tool_response_summary(
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.HYBRID,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
@@ -170,7 +214,7 @@ def _check_should_force_search(
args = {"query": new_msg_req.message}
return ForceUseTool(
tool_name=SearchTool.NAME,
tool_name=SearchTool._NAME,
args=args,
)
return None
@@ -338,7 +382,7 @@ def stream_chat_message_objects(
)
selected_db_search_docs = None
selected_llm_docs: list[LlmDoc] | None = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
@@ -348,8 +392,8 @@ def stream_chat_message_objects(
)
# Generates full documents currently
# May extend to include chunk ranges
selected_llm_docs = inference_documents_from_ids(
# May extend to use sections instead in the future
selected_sections = inference_sections_from_ids(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
@@ -428,20 +472,20 @@ def stream_chat_message_objects(
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
dalle_key = llm.config.api_key
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
@@ -458,13 +502,31 @@ def stream_chat_message_objects(
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=dalle_key,
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
@@ -571,6 +633,15 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
@@ -612,7 +683,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name()] = tool_id
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,


@@ -223,6 +223,11 @@ MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
# comma delimited list of zendesk article labels to skip indexing for
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
"ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")
#####
# Indexing Configs
@@ -243,13 +248,15 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
@@ -266,10 +273,14 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
# Sets LiteLLM to verbose logging
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# Logs Danswer only model interactions like prompts, responses, messages etc.
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (


@@ -5,7 +5,10 @@ PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@@ -25,9 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
@@ -73,8 +75,22 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None


@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -43,7 +44,8 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
INDEX_SEPARATOR = "==="
@@ -104,6 +106,7 @@ class DocumentSource(str, Enum):
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"
class BlobType(str, Enum):
@@ -112,6 +115,9 @@ class BlobType(str, Enum):
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
# Special case, for internet search
NOT_APPLICABLE = "not_applicable"
class DocumentIndexType(str, Enum):
COMBINED = "combined" # Vespa


@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Auto detect query options like time cutoff or heavily favor recently updated docs
DISABLE_DANSWER_BOT_FILTER_DETECT = (
os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually


@@ -39,8 +39,8 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
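
The expected score range changes from a symmetric [-12, 12] logit range to [0, 1]. As the comment notes, these constants exist only so scores can be displayed sensibly; a hedged sketch of min-max normalization under the assumption that they are consumed that way (the actual consumer lives elsewhere in the codebase):

CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0

def display_score(raw_score: float) -> float:
    # Clamp into the expected range, then rescale to [0, 1] for display
    clamped = max(CROSS_ENCODER_RANGE_MIN, min(CROSS_ENCODER_RANGE_MAX, raw_score))
    return (clamped - CROSS_ENCODER_RANGE_MIN) / (
        CROSS_ENCODER_RANGE_MAX - CROSS_ENCODER_RANGE_MIN
    )

print(display_score(0.87))  # 0.87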


@@ -37,16 +37,18 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
# Potential Improvements
# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space
# 2. Include attachments, etc
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview
wiki_base is https://danswer.atlassian.net/wiki
space is 1234abcd
page_id is 5678efgh
"""
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -55,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
+ parsed_url.netloc
+ parsed_url.path.split("/spaces")[0]
)
space = parsed_url.path.split("/")[3]
return wiki_base, space
path_parts = parsed_url.path.split("/")
space = path_parts[3]
page_id = path_parts[5] if len(path_parts) > 5 else ""
return wiki_base, space, page_id
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
https://danswer.ai/confluence/display/1234abcd/overview
URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
wiki_base is https://danswer.ai/confluence
space is 1234abcd
page_id is 5678efgh
"""
# /display/ is always right before the space and at the end of the base url
# /display/ is always right before the space and at the end of the base print()
DISPLAY = "/display/"
PAGE = "/pages/"
parsed_url = urlparse(wiki_url)
wiki_base = (
@@ -76,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, st
+ parsed_url.path.split(DISPLAY)[0]
)
space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
return wiki_base, space
page_id = ""
if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
page_id = content[1]
return wiki_base, space, page_id
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
is_confluence_cloud = (
".atlassian.net/wiki/spaces/" in wiki_url
or ".jira.com/wiki/spaces/" in wiki_url
@@ -87,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
try:
if is_confluence_cloud:
wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
wiki_url
)
else:
wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
wiki_url
)
except Exception as e:
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
return wiki_base, space, is_confluence_cloud
return wiki_base, space, page_id, is_confluence_cloud
@lru_cache()
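
For reference, the cloud-URL parsing added above can be exercised standalone; a self-contained sketch using urlparse with the sample URL from the docstring (illustrative only, not part of the diff):

from urllib.parse import urlparse

def parse_cloud_wiki_url(wiki_url: str) -> tuple[str, str, str]:
    parsed = urlparse(wiki_url)
    # Everything before /spaces is the wiki base
    wiki_base = f"{parsed.scheme}://{parsed.netloc}" + parsed.path.split("/spaces")[0]
    parts = parsed.path.split("/")
    space = parts[3]
    # /spaces/<space>/pages/<page_id>/... holds the page id when present
    page_id = parts[5] if len(parts) > 5 else ""
    return wiki_base, space, page_id

print(parse_cloud_wiki_url(
    "https://danswer.atlassian.net/wiki/spaces/ASAM/pages/5678efgh/overview"
))  # ('https://danswer.atlassian.net/wiki', 'ASAM', '5678efgh')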
@@ -196,10 +212,135 @@ def _comment_dfs(
return comments_str
class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: Confluence,
index_origin: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.confluence_client = confluence_client
self.index_origin = index_origin
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def _fetch_origin_page(
self,
) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version"
)
return origin_page
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
)
return {}
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
return child_pages
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
return child_pages
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_origin: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -210,11 +351,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
wiki_page_url
)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_origin = index_origin
(
self.wiki_base,
self.space,
self.page_id,
self.is_cloud,
) = extract_confluence_keys_from_url(wiki_page_url)
self.space_level_scan = False
self.confluence_client: Confluence | None = None
if self.page_id is None or self.page_id == "":
self.space_level_scan = True
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
@@ -232,8 +389,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self,
confluence_client: Confluence,
start_ind: int,
) -> Collection[dict[str, Any]]:
def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]:
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
@@ -242,9 +399,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind,
limit=batch_size,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
except Exception:
@@ -263,9 +422,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.space,
start=start_ind + i,
limit=1,
status="current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None,
status=(
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
)
@@ -286,17 +447,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return view_pages
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_origin=self.index_origin,
)
return self.recursive_indexer.get_pages(start_ind, batch_size)
pages: list[dict[str, Any]] = []
try:
return _fetch(start_ind, self.batch_size)
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception as e:
if not self.continue_on_failure:
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
pages: list[dict[str, Any]] = []
for i in range(self.batch_size):
try:
pages.extend(_fetch(start_ind + i, 1))
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
@@ -308,6 +493,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)
try:
comment_pages = cast(
Collection[dict[str, Any]],
@@ -356,7 +542,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_id, start=0, limit=500
)
for attachment in attachments_container["results"]:
if attachment["metadata"]["mediaType"] in ["image/jpeg", "image/png"]:
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
continue
if attachment["title"] not in files_in_used:
@@ -367,9 +560,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if response.status_code == 200:
extract = extract_file_text(
attachment["title"],
io.BytesIO(response.content),
break_on_unprocessable=False,
attachment["title"], io.BytesIO(response.content), False
)
files_attachment_content.append(extract)
@@ -389,8 +580,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified_str = page["version"]["when"]
author = cast(str | None, page["version"].get("by", {}).get("email"))
@@ -405,6 +596,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter is None or time_filter(last_modified):
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
@@ -416,6 +608,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -436,7 +629,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": self.space
}
@@ -450,9 +642,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=[BasicExpertInfo(email=author)]
if author
else None,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
)
)
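
The RecursiveIndexer added above walks a page's descendants level by level, batching child fetches and optionally appending the origin page at the end. The same traversal pattern as a minimal, self-contained sketch, with a hypothetical in-memory tree standing in for the Confluence client:

from collections import deque

# Hypothetical page tree standing in for Confluence's child-page API
TREE = {"root": ["a", "b"], "a": ["a1"], "b": [], "a1": []}

def fetch_children(page_id: str) -> list[str]:
    return TREE.get(page_id, [])

def collect_descendants(origin: str, include_origin: bool = True) -> list[str]:
    # Level-order walk, mirroring recurse_children_pages
    pages: list[str] = []
    queue = deque(fetch_children(origin))
    while queue:
        page = queue.popleft()
        pages.append(page)
        queue.extend(fetch_children(page))
    if include_origin:
        pages.append(origin)
    return pages

print(collect_descendants("root"))  # ['a', 'b', 'a1', 'root']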


@@ -1,10 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar
from requests import HTTPError
from retry import retry
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
@retry(
exceptions=ConfluenceRateLimitError,
tries=10,
delay=1,
max_delay=600, # 10 minutes
backoff=2,
jitter=1,
)
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
raise ConfluenceRateLimitError()
raise
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(10):
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
time.sleep(retry_after)
else:
logger.warning(
"Rate limit hit. Retrying with exponential backoff..."
)
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
else:
# re-raise, let caller handle
raise
return cast(F, wrapped_call)
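
The rewritten wrapper sleeps for the server-provided Retry-After when the header is present and otherwise falls back to exponential backoff, delay = min(5 * 2**attempt, 600). A quick sketch of the resulting worst-case schedule when no Retry-After is returned:

# starting_delay=5, backoff=2, max_delay=600, 10 attempts (as above)
delays = [min(5 * 2**attempt, 600) for attempt in range(10)]
print(delays)       # [5, 10, 20, 40, 80, 160, 320, 600, 600, 600]
print(sum(delays))  # 2435 seconds, roughly 40 minutes of waiting in total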


@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.text_processing import is_valid_email
@@ -57,3 +58,7 @@ def process_in_batches(
) -> Iterator[list[U]]:
for i in range(0, len(objects), batch_size):
yield [process_function(obj) for obj in objects[i : i + batch_size]]
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]


@@ -11,6 +11,9 @@ from requests import Response
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
@@ -58,67 +61,36 @@ class DiscourseConnector(PollConnector):
self.category_id_map: dict[int, str] = {}
self.batch_size = batch_size
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")
return discourse_request(endpoint, self.permissions, params)
def _get_categories_map(
self,
) -> None:
assert self.permissions is not None
categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
response = discourse_request(
response = self._make_request(
endpoint=categories_endpoint,
perms=self.permissions,
params={"include_subcategories": True},
)
categories = response.json()["category_list"]["categories"]
self.category_id_map = {
category["id"]: category["name"]
for category in categories
if not self.categories or category["name"].lower() in self.categories
cat["id"]: cat["name"]
for cat in categories
if not self.categories or cat["name"].lower() in self.categories
}
def _get_latest_topics(
self, start: datetime | None, end: datetime | None
) -> list[int]:
assert self.permissions is not None
topic_ids = []
valid_categories = set(self.category_id_map.keys())
latest_endpoint = urllib.parse.urljoin(self.base_url, "latest.json")
response = discourse_request(endpoint=latest_endpoint, perms=self.permissions)
topics = response.json()["topic_list"]["topics"]
for topic in topics:
last_time = topic.get("last_posted_at")
if not last_time:
continue
last_time_dt = time_str_to_utc(last_time)
if start and start > last_time_dt:
continue
if end and end < last_time_dt:
continue
if (
self.categories
and valid_categories
and topic.get("category_id") not in valid_categories
):
continue
topic_ids.append(topic["id"])
return topic_ids
self.active_categories = set(self.category_id_map)
def _get_doc_from_topic(self, topic_id: int) -> Document:
assert self.permissions is not None
topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
response = discourse_request(
endpoint=topic_endpoint,
perms=self.permissions,
)
response = self._make_request(endpoint=topic_endpoint)
topic = response.json()
topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")
@@ -167,26 +139,78 @@ class DiscourseConnector(PollConnector):
)
return doc
def _get_latest_topics(
self, start: datetime | None, end: datetime | None, page: int
) -> list[int]:
assert self.permissions is not None
topic_ids = []
if not self.categories:
latest_endpoint = urllib.parse.urljoin(
self.base_url, f"latest.json?page={page}"
)
response = self._make_request(endpoint=latest_endpoint)
topics = response.json()["topic_list"]["topics"]
else:
topics = []
empty_categories = []
for category_id in self.category_id_map.keys():
category_endpoint = urllib.parse.urljoin(
self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
)
response = self._make_request(endpoint=category_endpoint)
new_topics = response.json()["topic_list"]["topics"]
if len(new_topics) == 0:
empty_categories.append(category_id)
topics.extend(new_topics)
for empty_category in empty_categories:
self.category_id_map.pop(empty_category)
for topic in topics:
last_time = topic.get("last_posted_at")
if not last_time:
continue
last_time_dt = time_str_to_utc(last_time)
if (start and start > last_time_dt) or (end and end < last_time_dt):
continue
topic_ids.append(topic["id"])
if len(topic_ids) >= self.batch_size:
break
return topic_ids
def _yield_discourse_documents(
self, topic_ids: list[int]
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
page = 1
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:
doc_batch.append(self._get_doc_from_topic(topic_id))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if len(doc_batch) >= self.batch_size:
if doc_batch:
yield doc_batch
doc_batch = []
page += 1
if doc_batch:
yield doc_batch
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self.permissions = DiscoursePerms(
api_key=credentials["discourse_api_key"],
api_username=credentials["discourse_api_username"],
)
return None
def poll_source(
@@ -194,16 +218,13 @@ class DiscourseConnector(PollConnector):
) -> GenerateDocumentsOutput:
if self.permissions is None:
raise ConnectorMissingCredentialError("Discourse")
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
self._get_categories_map()
latest_topic_ids = self._get_latest_topics(
start=start_datetime, end=end_datetime
)
yield from self._yield_discourse_documents(latest_topic_ids)
yield from self._yield_discourse_documents(start_datetime, end_datetime)
if __name__ == "__main__":
@@ -219,7 +240,5 @@ if __name__ == "__main__":
current = time.time()
one_year_ago = current - 24 * 60 * 60 * 360
latest_docs = connector.poll_source(one_year_ago, current)
print(next(latest_docs))
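
The `rate_limit_builder(max_calls=100, period=60)` decorator applied to `_make_request` above throttles Discourse API calls client-side. A minimal sketch of that sliding-window idea, assuming nothing about the real implementation in `rate_limit_wrapper`:

import time
from collections import deque
from collections.abc import Callable
from typing import Any

def simple_rate_limiter(max_calls: int, period: float) -> Callable:
    # Block until fewer than max_calls have happened in the last `period` seconds.
    call_times: deque[float] = deque()

    def decorator(func: Callable) -> Callable:
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Drop timestamps that have aged out of the window
            while call_times and now - call_times[0] > period:
                call_times.popleft()
            if len(call_times) >= max_calls:
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return func(*args, **kwargs)
        return wrapped
    return decorator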

View File

@@ -85,6 +85,11 @@ def _process_file(
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
@@ -106,6 +111,7 @@ def _process_file(
for k, v in all_metadata.items()
if k
not in [
"document_id",
"time_updated",
"doc_updated_at",
"link",
@@ -132,7 +138,7 @@ def _process_file(
return [
Document(
id=f"FILE_CONNECTOR__{file_name}", # add a prefix to avoid conflicts with other connectors
id=doc_id,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
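
For reference, the new `document_id` override above is driven by per-file metadata. A hypothetical payload (the values and file name are made up; only `document_id`, `file_display_name`, and `link` come from this diff):

metadata = {
    "document_id": "my-stable-id-123",  # overrides the FILE_CONNECTOR__<name> default
    "file_display_name": "Quarterly Report",
    "link": "https://example.com/report.pdf",
}
doc_id = "FILE_CONNECTOR__report.pdf"
if metadata:
    doc_id = metadata.get("document_id") or doc_id
assert doc_id == "my-stable-id-123"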

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.utils.text_processing import make_url_compatible
@@ -117,7 +118,12 @@ class DocumentBase(BaseModel):
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
return self.semantic_identifier if self.title is None else self.title
replace_chars = set(RETURN_SEPARATOR)
title = self.semantic_identifier if self.title is None else self.title
for char in replace_chars:
title = title.replace(char, " ")
title = title.strip()
return title
def get_metadata_str_attributes(self) -> list[str] | None:
if not self.metadata:
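
A quick check of the separator-stripping behavior above, assuming `RETURN_SEPARATOR` is a short marker string (its real value lives in `danswer.configs.constants`):

RETURN_SEPARATOR = "\n\r\n"  # assumed placeholder; see danswer.configs.constants
title = "Weekly\nSync\rNotes  "
for char in set(RETURN_SEPARATOR):
    title = title.replace(char, " ")
print(title.strip())  # "Weekly Sync Notes"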

View File

@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time <= end or compare_time > start:
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
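
The corrected Notion condition keeps only pages edited inside the (start, end] window; the old `<= end or > start` form matched every page. A minimal check:

start, end = 100.0, 200.0
for compare_time in (50.0, 150.0, 250.0):
    old = compare_time <= end or compare_time > start   # always True here
    new = compare_time > start and compare_time <= end  # only 150.0 passes
    print(compare_time, old, new)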

View File

@@ -79,8 +79,9 @@ class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
extracted_id = f"{ID_PREFIX}{object_dict['Id']}"
extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
salesforce_id = object_dict["Id"]
danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
@@ -91,7 +92,7 @@ class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
]
doc = Document(
id=extracted_id,
id=danswer_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
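
The Salesforce fix above builds the instance link from the raw record Id while keeping the prefixed id for Danswer's document store. A small illustration; the instance host, Id, and `ID_PREFIX` value are made up:

ID_PREFIX = "SALESFORCE_"  # assumed; defined elsewhere in the connector
object_dict = {"Id": "0015f00000ABCDE"}
sf_instance = "example.my.salesforce.com"

salesforce_id = object_dict["Id"]
danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{sf_instance}/{salesforce_id}"
# Previously the link embedded the prefixed id, producing a broken URL:
# https://example.my.salesforce.com/SALESFORCE_0015f00000ABCDE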

View File

@@ -29,6 +29,7 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.html_utils import web_html_cleanup
from danswer.utils.logger import setup_logger
from danswer.utils.sitemap import list_pages_for_site
logger = setup_logger()
@@ -145,16 +146,21 @@ def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
result = [
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if not result:
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return result
return urls
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
@@ -264,7 +270,7 @@ class WebConnector(LoadConnector):
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split(".")[-1],
semantic_identifier=current_url.split("/")[-1],
metadata={},
)
)
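
The `semantic_identifier` fix above is easy to sanity-check: splitting on "." took the URL's TLD plus path tail, while splitting on "/" takes the last path segment:

current_url = "https://docs.example.com/guides/getting-started"
print(current_url.split(".")[-1])  # "com/guides/getting-started" (old, wrong)
print(current_url.split("/")[-1])  # "getting-started" (new)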

View File

@@ -4,6 +4,7 @@ from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
@@ -81,7 +82,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
)
doc_batch = []
for article in articles:
if article.body is None or article.draft:
if (
article.body is None
or article.draft
or any(
label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
for label in article.label_names
)
):
continue
doc_batch.append(_article_to_document(article))
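
The skip-labels check above reads `ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS` from config. One plausible way such a setting is supplied is a comma-separated environment variable; the parsing below is a hypothetical sketch, not the actual definition in `danswer.configs.app_configs`:

import os

# Hypothetical parsing; the real definition lives in danswer.configs.app_configs.
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = [
    label.strip().lower()
    for label in os.environ.get("ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", "").split(",")
    if label.strip()
]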

View File

@@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.icons import source_to_github_img_link
@@ -353,6 +354,22 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_standard_answer_blocks(
answer_message: str,
) -> list[Block]:
generate_button_block = ButtonElement(
action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
text="Generate Full Answer",
)
answer_block = SectionBlock(text=answer_message)
return [
answer_block,
ActionsBlock(
elements=[generate_button_block],
),
]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
@@ -457,7 +474,7 @@ def build_follow_up_resolved_blocks(
if tag_str:
tag_str += " "
group_str = " ".join([f"<!subteam^{group}>" for group in group_ids])
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
if group_str:
group_str += " "
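
The `<!subteam^{group_id}|>` form above is Slack's user-group mention syntax; the trailing `|` appears to let Slack resolve the group handle itself rather than requiring it inline. A quick illustration with made-up group ids:

group_ids = ["S012AB3CD", "S045EF6GH"]
group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
print(group_str)  # <!subteam^S012AB3CD|> <!subteam^S045EF6GH|>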

View File

@@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"
class FeedbackVisibility(str, Enum):

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from typing import cast
@@ -8,6 +9,7 @@ from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import make_slack_api_rate_limited
@@ -21,12 +23,17 @@ from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
)
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_groupids_from_names
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
@@ -36,7 +43,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger_base = setup_logger()
logger = setup_logger()
def handle_doc_feedback_button(
@@ -44,7 +51,7 @@ def handle_doc_feedback_button(
client: SocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger_base.error("Missing actions. Unable to build the source feedback view")
logger.error("Missing actions. Unable to build the source feedback view")
return
# Extracts the feedback_id coming from the 'source feedback' button
@@ -72,6 +79,66 @@ def handle_doc_feedback_button(
)
def handle_generate_answer_button(
req: SocketModeRequest,
client: SocketModeClient,
) -> None:
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
user_id = req.payload["user"]["id"]
if not thread_ts:
raise ValueError("Missing thread_ts in the payload")
thread_messages = read_slack_thread(
channel=channel_id, thread=thread_ts, client=client.web_client
)
# remove all assistant messages till we get to the last user message
# we want the new answer to be generated off of the last "question" in
# the thread
for i in range(len(thread_messages) - 1, -1, -1):
if thread_messages[i].role == MessageType.USER:
break
if thread_messages[i].role == MessageType.ASSISTANT:
thread_messages.pop(i)
# Tell the user we're working on it by sending an ephemeral message
# that the full answer is being generated
respond_in_thread(
client=client.web_client,
channel=channel_id,
receiver_ids=[user_id],
text="I'm working on generating a full answer for you. This may take a moment...",
thread_ts=thread_ts,
)
with Session(get_sqlalchemy_engine()) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
handle_regular_answer(
message_info=SlackMessageInfo(
thread_messages=thread_messages,
channel_to_respond=channel_id,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
bypass_filters=True,
is_bot_msg=False,
is_bot_dm=False,
),
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
channel=channel_id,
logger=cast(logging.Logger, logger),
feedback_reminder_id=None,
)
def handle_slack_feedback(
feedback_id: str,
feedback_type: str,
@@ -129,7 +196,7 @@ def handle_slack_feedback(
feedback=feedback,
)
else:
logger_base.error(f"Feedback type '{feedback_type}' not supported")
logger.error(f"Feedback type '{feedback_type}' not supported")
if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [
LIKE_BLOCK_ACTION_ID,
@@ -193,11 +260,11 @@ def handle_followup_button(
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_userids_from_emails(
tag_ids, remaining = fetch_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:
group_ids, _ = fetch_groupids_from_names(remaining, client.web_client)
group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client)
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
@@ -272,7 +339,7 @@ def handle_followup_resolved_button(
)
if not response.get("ok"):
logger_base.error("Unable to delete message for resolved")
logger.error("Unable to delete message for resolved")
if immediate:
msg_text = f"{clicker_name} has marked this question as resolved!"
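
The reverse loop in `handle_generate_answer_button` above drops trailing assistant messages so the regenerated answer targets the last user question. A standalone sketch of that trim, using plain strings in place of `ThreadMessage` objects:

thread_messages = ["USER: q1", "ASSISTANT: a1", "USER: q2", "ASSISTANT: a2", "ASSISTANT: a2b"]
for i in range(len(thread_messages) - 1, -1, -1):
    if thread_messages[i].startswith("USER"):
        break
    if thread_messages[i].startswith("ASSISTANT"):
        thread_messages.pop(i)
print(thread_messages)  # ['USER: q1', 'ASSISTANT: a1', 'USER: q2']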

View File

@@ -1,92 +1,34 @@
import datetime
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_userids_from_groups
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
logger_base = setup_logger()
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender:
@@ -174,17 +116,9 @@ def remove_scheduled_feedback_reminder(
def handle_message(
message_info: SlackMessageInfo,
channel_config: SlackBotConfig | None,
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -201,14 +135,22 @@ def handle_message(
)
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = channel_config.persona if channel_config else None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
@@ -216,36 +158,13 @@ def handle_message(
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if channel_config is None
else channel_config.response_type == SlackBotResponseType.CITATIONS
)
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
respond_tag_only = False
respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
respond_member_group_list = None
channel_conf = None
if channel_config and channel_config.channel_config:
channel_conf = channel_config.channel_config
if slack_bot_config and slack_bot_config.channel_config:
channel_conf = slack_bot_config.channel_config
if not bypass_filters and "answer_filters" in channel_conf:
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
and "?" not in messages[-1].message
@@ -262,8 +181,7 @@ def handle_message(
)
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_team_member_list = channel_conf.get("respond_team_member_list") or None
respond_slack_group_list = channel_conf.get("respond_slack_group_list") or None
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
if respond_tag_only and not bypass_filters:
logger.info(
@@ -272,17 +190,23 @@ def handle_message(
)
return False
if respond_team_member_list:
send_to, _ = fetch_userids_from_emails(respond_team_member_list, client)
if respond_slack_group_list:
user_ids, _ = fetch_userids_from_groups(respond_slack_group_list, client)
send_to = (send_to + user_ids) if send_to else user_ids
if send_to:
send_to = list(set(send_to)) # remove duplicates
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
respond_member_group_list, client
)
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client)
send_to = list(set(send_to + user_ids)) if send_to else user_ids
if missing_users:
logger.warning(f"Failed to find these users/groups: {missing_users}")
# If configured to respond to team members only, then cannot be used with a /DanswerBot command
# which would just respond to the sender
if (respond_team_member_list or respond_slack_group_list) and is_bot_msg:
if send_to and is_bot_msg:
if sender_id:
respond_in_thread(
client=client,
@@ -297,324 +221,28 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client)
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
# it allows the slack flow to extract out filters from the user query
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
with Session(get_sqlalchemy_engine()) as db_session:
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
db_session=db_session,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
auto_detect_filters = (
persona.llm_filter_extraction if persona is not None else True
)
if disable_auto_detect_filters:
auto_detect_filters = False
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
# Optionally, respond in thread with the error message. Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
return True
# Edge case handling, for tracking down the Slack usage issue
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=send_to,
text="Hello! Danswer has some results for you!",
blocks=[
SectionBlock(
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if respond_team_member_list or respond_slack_group_list:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
if used_standard_answer:
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# Got an answer at this point, can remove reaction and give results
try:
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
# if no standard answer applies, try a regular answer
issue_with_regular_answer = handle_regular_answer(
message_info=message_info,
slack_bot_config=slack_bot_config,
receiver_ids=send_to,
client=client,
channel=channel,
receiver_ids=send_to,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if respond_team_member_list or respond_slack_group_list:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
return issue_with_regular_answer
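
The new `respond_member_group_list` resolution above first matches emails, then falls back to group lookups for anything unmatched, deduplicating the result. A simplified sketch with stub lookup functions standing in for the Slack-backed `fetch_user_ids_from_emails` / `fetch_user_ids_from_groups` (both stubs are hypothetical):

def fetch_user_ids_from_emails(names: list[str], client: object) -> tuple[list[str], list[str]]:
    found = {"a@x.com": "U1", "b@x.com": "U2"}  # stub directory
    return [found[n] for n in names if n in found], [n for n in names if n not in found]

def fetch_user_ids_from_groups(names: list[str], client: object) -> tuple[list[str], list[str]]:
    groups = {"eng-oncall": ["U2", "U3"]}  # stub group membership
    ids = [uid for n in names if n in groups for uid in groups[n]]
    return ids, [n for n in names if n not in groups]

respond_member_group_list = ["a@x.com", "eng-oncall", "ghost@x.com"]
send_to, missing_ids = fetch_user_ids_from_emails(respond_member_group_list, None)
user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, None)
send_to = list(set(send_to + user_ids)) if send_to else user_ids
print(sorted(send_to), missing_users)  # ['U1', 'U2', 'U3'] ['ghost@x.com']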

View File

@@ -0,0 +1,465 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def handle_regular_answer(
message_info: SlackMessageInfo,
slack_bot_config: SlackBotConfig | None,
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
logger: logging.Logger,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
bypass_acl = False
if (
slack_bot_config
and slack_bot_config.persona
and slack_bot_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_bot_config is None
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to:
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
)
@retry(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
# it allows the slack flow to extract out filters from the user query
filters = BaseFilters(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
)
# Default True because no other ways to apply filters in Slack (no nice UI)
# Commenting this out because this is only available to the slackbot for now
# later we plan to implement this at the persona level where this will get
# commented back in
# auto_detect_filters = (
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = (
slack_bot_config.enable_auto_filters
if slack_bot_config is not None
else False
)
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
)
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
)
# Optionally, respond in thread with the error message. Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
return True
# Edge case handling, for tracking down the Slack usage issue
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=[
SectionBlock(
text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=False,
client=client,
)
if answer.answer:
logger.debug(answer.answer)
return True
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
)
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)
return True
only_respond_with_citations_or_quotes = (
channel_conf
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer.citations or answer.quotes)
if (
only_respond_with_citations_or_quotes
and not has_citations_or_quotes
and not message_info.bypass_filters
):
logger.error(
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text="Found no citations or quotes when trying to answer.",
thread_ts=message_ts_to_respond_to,
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
send_team_member_message(
client=client,
channel=channel,
thread_ts=message_ts_to_respond_to,
)
return False
except Exception:
logger.exception(
f"Unable to process message - could not respond in slack in {num_retries} attempts"
)
return True
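
The new `well_answered_postfilter` gate above only lets an answer through when it cites or quotes sources, unless filters are bypassed. Condensed to its essentials with stand-in values:

channel_conf = {"answer_filters": ["well_answered_postfilter"]}
answer_citations, answer_quotes, bypass_filters = [], None, False

only_respond_with_citations_or_quotes = (
    channel_conf and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer_citations or answer_quotes)
if only_respond_with_citations_or_quotes and not has_citations_or_quotes and not bypass_filters:
    print("No citations or quotes - not answering!")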

View File

@@ -0,0 +1,216 @@
import logging
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_messages_by_sessions
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger
logger = setup_logger()
def oneoff_standard_answers(
message: str,
slack_bot_categories: list[str],
db_session: Session,
) -> list[StandardAnswer]:
"""
Check the message against the standard answers configured for the given categories.
Returns the list of matching StandardAnswers, which is empty if none match.
"""
configured_standard_answers = {
standard_answer
for category in fetch_standard_answer_categories_by_names(
slack_bot_categories, db_session=db_session
)
for standard_answer in category.standard_answers
}
matching_standard_answers = find_matching_standard_answers(
query=message,
id_in=[answer.id for answer in configured_standard_answers],
db_session=db_session,
)
server_standard_answers = [
StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
]
return server_standard_answers
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: logging.Logger,
client: WebClient,
db_session: Session,
) -> bool:
"""
Potentially respond to the user message depending on whether the user's message matches
any of the configured standard answers and also whether those answers have already been
provided in the current thread.
Returns True if a matching standard answer was found and sent, in which case the
caller should skip the regular answer flow.
"""
# if no channel config, then no standard answers are configured
if not slack_bot_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_bot_config.standard_answer_categories if slack_bot_config else []
)
configured_standard_answers = set(
[
standard_answer
for standard_answer_category in configured_standard_answer_categories
for standard_answer in standard_answer_category.standard_answers
]
)
query_msg = message_info.thread_messages[-1]
if slack_thread_id is None:
used_standard_answer_ids = set([])
else:
chat_sessions = get_chat_sessions_by_slack_thread_id(
slack_thread_id=slack_thread_id,
user_id=None,
db_session=db_session,
)
chat_messages = get_chat_messages_by_sessions(
chat_session_ids=[chat_session.id for chat_session in chat_sessions],
user_id=None,
db_session=db_session,
skip_permission_check=True,
)
used_standard_answer_ids = set(
[
standard_answer.id
for chat_message in chat_messages
for standard_answer in chat_message.standard_answers
]
)
usable_standard_answers = configured_standard_answers.difference(
used_standard_answer_ids
)
if usable_standard_answers:
matching_standard_answers = find_matching_standard_answers(
query=query_msg.message,
id_in=[standard_answer.id for standard_answer in usable_standard_answers],
db_session=db_session,
)
else:
matching_standard_answers = []
if matching_standard_answers:
chat_session = create_chat_session(
db_session=db_session,
description="",
user_id=None,
persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
danswerbot_flow=True,
slack_thread_id=slack_thread_id,
one_shot=True,
)
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
formatted_answers = []
for standard_answer in matching_standard_answers:
block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
formatted_answer = (
f'Since you mentioned _"{standard_answer.keyword}"_, '
f"I thought this might be useful: \n\n{block_quotified_answer}"
)
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
message=answer_message,
token_count=0,
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=True,
)
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
)
answer_blocks = build_standard_answer_blocks(
answer_message=answer_message,
)
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_info.msg_to_respond,
unfurl=False,
)
if receiver_ids and slack_thread_id:
send_team_member_message(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
)
return True
except Exception as e:
logger.exception(f"Unable to send standard answer message: {e}")
return False
else:
return False
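
The block-quoting step above turns each standard answer into Slack-flavored quoted text. In isolation, with a made-up answer and keyword:

answer = "Restart the service.\nThen clear the cache."
keyword = "restart"
block_quotified_answer = ">" + answer.replace("\n", "\n> ")
formatted_answer = (
    f'Since you mentioned _"{keyword}"_, '
    f"I thought this might be useful: \n\n{block_quotified_answer}"
)
print(formatted_answer)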

View File

@@ -0,0 +1,19 @@
from slack_sdk import WebClient
from danswer.danswerbot.slack.utils import respond_in_thread
def send_team_member_message(
client: WebClient,
channel: str,
thread_ts: str,
) -> None:
respond_in_thread(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=thread_ts,
)

View File

@@ -18,6 +18,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
@@ -27,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_but
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_followup_resolved_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import (
handle_generate_answer_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.handlers.handle_message import (
@@ -266,6 +270,7 @@ def build_request_details(
thread_messages=thread_messages,
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=event.get("user") or None,
bypass_filters=tagged,
is_bot_msg=False,
@@ -283,6 +288,7 @@ def build_request_details(
thread_messages=[single_msg],
channel_to_respond=channel,
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
bypass_filters=True,
is_bot_msg=True,
@@ -352,7 +358,7 @@ def process_message(
failed = handle_message(
message_info=details,
channel_config=slack_bot_config,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
@@ -390,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_followup_resolved_button(req, client, immediate=True)
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
return handle_followup_resolved_button(req, client, immediate=False)
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
@@ -461,13 +469,13 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens
# may potentially cause a message to be dropped, but avoiding it is complicated

View File

@@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel):
thread_messages: list[ThreadMessage]
channel_to_respond: str
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot

View File

@@ -77,17 +77,25 @@ def update_emote_react(
remove: bool,
client: WebClient,
) -> None:
if not message_ts:
logger.error(f"Tried to remove a react in {channel} but no message specified")
return
try:
if not message_ts:
logger.error(
f"Tried to remove a react in {channel} but no message specified"
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
else:
logger.error(f"Was not able to react to user message due to: {e}")
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
@@ -136,16 +144,13 @@ def respond_in_thread(
receiver_ids: list[str] | None = None,
metadata: Metadata | None = None,
unfurl: bool = True,
) -> None:
) -> list[str]:
if not text and not blocks:
raise ValueError("One of `text` or `blocks` must be provided")
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
if not receiver_ids:
response = slack_call(
channel=channel,
text=text,
@@ -157,7 +162,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
response = slack_call(
channel=channel,
@@ -171,6 +178,9 @@ def respond_in_thread(
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
return message_ids
def build_feedback_id(
@@ -292,7 +302,7 @@ def get_channel_name_from_id(
raise e
def fetch_userids_from_emails(
def fetch_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -308,57 +318,72 @@ def fetch_userids_from_emails(
return user_ids, failed_to_find
def fetch_userids_from_groups(
group_names: list[str], client: WebClient
def fetch_user_ids_from_groups(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
failed_to_find: list[str] = []
for group_name in group_names:
try:
# First, find the group ID from the group name
response = client.usergroups_list()
groups = {group["name"]: group["id"] for group in response["usergroups"]}
group_id = groups.get(group_name)
try:
response = client.usergroups_list()
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return user_ids, given_names
if group_id:
# Fetch user IDs for the group
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
group_id = name_id_map.get(given_name) or handle_id_map.get(
given_name.lstrip("@")
)
if not group_id:
failed_to_find.append(given_name)
continue
try:
response = client.usergroups_users_list(usergroup=group_id)
user_ids.extend(response["users"])
else:
failed_to_find.append(group_name)
except Exception as e:
logger.error(f"Error fetching user IDs for group {group_name}: {str(e)}")
failed_to_find.append(group_name)
if isinstance(response.data, dict):
user_ids.extend(response.data.get("users", []))
else:
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user group ids: {str(e)}")
failed_to_find.append(given_name)
except Exception as e:
logger.error(f"Error fetching user groups: {str(e)}")
failed_to_find = given_names
return user_ids, failed_to_find
def fetch_groupids_from_names(
names: list[str], client: WebClient
def fetch_group_ids_from_names(
given_names: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
group_ids: set[str] = set()
group_data: list[str] = []
failed_to_find: list[str] = []
try:
response = client.usergroups_list()
if response.get("ok") and "usergroups" in response.data:
all_groups_dicts = response.data["usergroups"] # type: ignore
name_id_map = {d["name"]: d["id"] for d in all_groups_dicts}
handle_id_map = {d["handle"]: d["id"] for d in all_groups_dicts}
for group in names:
if group in name_id_map:
group_ids.add(name_id_map[group])
elif group in handle_id_map:
group_ids.add(handle_id_map[group])
else:
failed_to_find.append(group)
else:
# Most likely a Slack App scope issue
if not isinstance(response.data, dict):
logger.error("Error fetching user groups")
return group_data, given_names
all_group_data = response.data.get("usergroups", [])
name_id_map = {d["name"]: d["id"] for d in all_group_data}
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
for given_name in given_names:
id = handle_id_map.get(given_name.lstrip("@"))
id = id or name_id_map.get(given_name)
if id:
group_data.append(id)
else:
failed_to_find.append(given_name)
except Exception as e:
failed_to_find = given_names
logger.error(f"Error fetching user groups: {str(e)}")
return list(group_ids), failed_to_find
return group_data, failed_to_find
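
Both lookups accept either a group's display name or its @handle (the leading "@" is stripped before the handle lookup), and collect failures rather than raising. A hedged usage sketch with placeholder names, assuming `client` is a configured WebClient:

# Hypothetical usage; the group names/handles are placeholders.
group_ids, not_found_groups = fetch_group_ids_from_names(
    given_names=["oncall-team", "@support"], client=client
)
user_ids, not_found_users = fetch_user_ids_from_groups(
    given_names=["@oncall-team"], client=client
)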
def fetch_user_semantic_id_from_id(

View File

@@ -1,16 +1,22 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
@@ -33,6 +39,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -67,17 +74,59 @@ def get_chat_session_by_id(
return chat_session
def get_chat_sessions_by_slack_thread_id(
slack_thread_id: str,
user_id: UUID | None,
db_session: Session,
) -> Sequence[ChatSession]:
stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id)
if user_id is not None:
stmt = stmt.where(
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
)
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
subquery = (
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
.where(
and_(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER, # Select USER messages
)
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
subquery,
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
& (ChatMessage.id == subquery.c.min_id),
)
first_messages = db_session.execute(query).all()
return {row.chat_session_id: row.message for row in first_messages}
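
The min-id subquery above picks each session's earliest USER message, so callers can label sessions cheaply. A hedged caller sketch (the session IDs are placeholders; assumes an open SQLAlchemy session as elsewhere in this file):

# Hypothetical caller sketch.
first_messages = get_first_messages_for_chat_sessions(
    chat_session_ids=[1, 2, 3], db_session=db_session
)
for session_id, message in first_messages.items():
    print(f"session {session_id}: {message[:50]}")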
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_one_shot: bool = False,
only_one_shot: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_one_shot:
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -97,6 +146,12 @@ def delete_search_doc_message_relationship(
db_session.commit()
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
db_session.commit()
def delete_orphaned_search_docs(db_session: Session) -> None:
orphaned_docs = (
db_session.query(SearchDoc)
@@ -120,6 +175,7 @@ def delete_messages_and_files_from_chat_session(
).fetchall()
for id, files in messages_with_files:
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
for file_info in files or {}:
lobj_name = file_info.get("id")
@@ -139,11 +195,12 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int | None = None,
persona_id: int,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
danswerbot_flow: bool = False,
slack_thread_id: str | None = None,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
@@ -153,6 +210,7 @@ def create_chat_session(
prompt_override=prompt_override,
one_shot=one_shot,
danswerbot_flow=danswerbot_flow,
slack_thread_id=slack_thread_id,
)
db_session.add(chat_session)
@@ -240,6 +298,39 @@ def get_chat_message(
return chat_message
def get_chat_messages_by_sessions(
chat_session_ids: list[int],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
) -> Sequence[ChatMessage]:
if not skip_permission_check:
for chat_session_id in chat_session_ids:
get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
stmt = (
select(ChatMessage)
.where(ChatMessage.chat_session_id.in_(chat_session_ids))
.order_by(nullsfirst(ChatMessage.parent_message))
)
return db_session.execute(stmt).scalars().all()
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
stmt = (
select(SearchDoc)
.join(
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
)
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
)
return list(db_session.scalars(stmt).all())
def get_chat_messages_by_session(
chat_session_id: int,
user_id: UUID | None,
@@ -260,8 +351,6 @@ def get_chat_messages_by_session(
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
if prefetch_tool_calls:
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -449,6 +538,27 @@ def get_doc_query_identifiers_from_model(
return doc_query_identifiers
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
)
if relevance_data is not None:
db_session.execute(
update(SearchDoc)
.where(SearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
@@ -463,17 +573,19 @@ def create_db_search_doc(
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=server_search_doc.relevance_explanation,
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=server_search_doc.match_highlights,
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
is_internet=server_search_doc.is_internet,
)
db_session.add(db_search_doc)
db_session.commit()
return db_search_doc
@@ -502,11 +614,14 @@ def translate_db_search_doc_to_server_search_doc(
match_highlights=(
db_search_doc.match_highlights if not remove_doc_content else []
),
relevance_explanation=db_search_doc.relevance_explanation,
is_relevant=db_search_doc.is_relevant,
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
secondary_owners=(
db_search_doc.secondary_owners if not remove_doc_content else []
),
is_internet=db_search_doc.is_internet,
)
@@ -524,9 +639,11 @@ def get_retrieval_docs_from_chat_message(
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage, remove_doc_content: bool = False
chat_message: ChatMessage,
remove_doc_content: bool = False,
) -> ChatMessageDetail:
chat_msg_detail = ChatMessageDetail(
chat_session_id=chat_message.chat_session_id,
message_id=chat_message.id,
parent_message=chat_message.parent_message,
latest_child_message=chat_message.latest_child_message,

View File

@@ -152,7 +152,7 @@ def add_credential_to_connector(
credential_id: int,
cc_pair_name: str | None,
is_public: bool,
user: User,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
@@ -201,7 +201,7 @@ def add_credential_to_connector(
def remove_credential_from_connector(
connector_id: int,
credential_id: int,
user: User,
user: User | None,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)

View File

@@ -12,6 +12,7 @@ from danswer.connectors.gmail.constants import (
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
@@ -142,6 +143,18 @@ def delete_credential(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
.all()
)
if associated_connectors:
raise ValueError(
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
"Please delete all associated connectors first."
)
db_session.delete(credential)
db_session.commit()

View File

@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -31,6 +36,7 @@ def create_embedding_model(
query_prefix=model_details.query_prefix,
passage_prefix=model_details.passage_prefix,
status=status,
cloud_provider_id=model_details.cloud_provider_id,
# Every embedding model except the initial one from the migrations uses this naming pattern
# The initial one from the migration is simply called "danswer_chunk"
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
@@ -42,6 +48,42 @@ def create_embedding_model(
return embedding_model
def get_model_id_from_name(
db_session: Session, embedding_provider_name: str
) -> int | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.name == embedding_provider_name
)
provider = db_session.execute(query).scalars().first()
return provider.id if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
current_embedding_model = EmbeddingModelDetail.from_model(
get_current_db_embedding_model(db_session=db_session)
)
if (
current_embedding_model is None
or current_embedding_model.cloud_provider_id is None
):
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
query = (
select(EmbeddingModel)

View File

@@ -2,11 +2,34 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
existing_provider = (
db_session.query(CloudEmbeddingProviderModel)
.filter_by(name=provider.name)
.first()
)
if existing_provider:
for key, value in provider.dict().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
db_session.refresh(existing_provider)
return CloudEmbeddingProvider.from_request(existing_provider)
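
The function follows a query-then-mutate-or-insert upsert pattern keyed on the provider name. A hedged usage sketch; the exact fields on the creation request are assumptions here:

# Hypothetical usage; the request field names are assumptions.
request = CloudEmbeddingProviderCreationRequest(
    name="openai", api_key="sk-placeholder"
)
provider = upsert_cloud_embedding_provider(db_session, request)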
def upsert_llm_provider(
db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
@@ -26,7 +49,6 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
@@ -46,10 +68,26 @@ def upsert_llm_provider(
return FullLLMProvider.from_model(llm_provider_model)
def fetch_existing_embedding_providers(
db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_embedding_provider(
db_session: Session, provider_id: int
) -> CloudEmbeddingProviderModel | None:
return db_session.scalar(
select(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.id == provider_id
)
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
return FullLLMProvider.from_model(provider_model)
def remove_embedding_provider(
db_session: Session, embedding_provider_name: str
) -> None:
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.name == embedding_provider_name
)
)
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)

View File

@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -246,6 +247,39 @@ class Persona__Tool(Base):
tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True)
class StandardAnswer__StandardAnswerCategory(Base):
__tablename__ = "standard_answer__standard_answer_category"
standard_answer_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
)
class SlackBotConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_bot_config__standard_answer_category"
slack_bot_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot_config.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
)
class ChatMessage__StandardAnswer(Base):
__tablename__ = "chat_message__standard_answer"
chat_message_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id"), primary_key=True
)
standard_answer_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer.id"), primary_key=True
)
"""
Documents/Indexing Tables
"""
@@ -436,7 +470,7 @@ class Credential(Base):
class EmbeddingModel(Base):
__tablename__ = "embedding_model"
# The ID is also used to indicate the order in which the models were configured by the admin
id: Mapped[int] = mapped_column(primary_key=True)
model_name: Mapped[str] = mapped_column(String)
model_dim: Mapped[int] = mapped_column(Integer)
@@ -448,6 +482,16 @@ class EmbeddingModel(Base):
)
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
cloud_provider_id: Mapped[int | None] = mapped_column(
ForeignKey("embedding_provider.id")
)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
foreign_keys=[cloud_provider_id],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="embedding_model"
)
@@ -467,6 +511,18 @@ class EmbeddingModel(Base):
),
)
def __repr__(self) -> str:
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
class IndexAttempt(Base):
"""
@@ -486,6 +542,7 @@ class IndexAttempt(Base):
ForeignKey("credential.id"),
nullable=True,
)
# Some index attempts that run from the beginning will still have this as False
# This is only for attempts that are explicitly marked as from the start via
# the run once API
@@ -612,6 +669,10 @@ class SearchDoc(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)
chat_messages = relationship(
"ChatMessage",
@@ -663,6 +724,10 @@ class ChatSession(Base):
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
slack_thread_id: Mapped[str | None] = mapped_column(
String, nullable=True, default=None
)
# the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the
# `send-message` call will take precedence over these.
@@ -760,6 +825,11 @@ class ChatMessage(Base):
"ToolCall",
back_populates="message",
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="chat_messages",
)
class ChatFolder(Base):
@@ -836,11 +906,6 @@ class ChatMessageFeedback(Base):
)
"""
Structures, Organizational, Configurations Tables
"""
class LLMProvider(Base):
__tablename__ = "llm_provider"
@@ -869,6 +934,29 @@ class LLMProvider(Base):
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
default_model_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("embedding_model.id"), nullable=True
)
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.cloud_provider_id",
)
default_model: Mapped["EmbeddingModel"] = relationship(
"EmbeddingModel", foreign_keys=[default_model_id]
)
def __repr__(self) -> str:
return f"<EmbeddingProvider(name='{self.name}')>"
class DocumentSet(Base):
__tablename__ = "document_set"
@@ -948,6 +1036,7 @@ class Tool(Base):
# ID of the tool in the codebase, only applies for in-code tools.
# tools defined via the UI will have this as None
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
display_name: Mapped[str] = mapped_column(String, nullable=True)
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
@@ -1077,14 +1166,60 @@ class ChannelConfig(TypedDict):
channel_names: list[str]
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
respond_team_member_list: NotRequired[list[str]]
respond_slack_group_list: NotRequired[list[str]]
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
@@ -1105,7 +1240,16 @@ class SlackBotConfig(Base):
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
persona: Mapped[Persona | None] = relationship("Persona")
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
)
class TaskQueueState(Base):

View File

@@ -12,8 +12,8 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import Persona
@@ -62,19 +62,6 @@ def create_update_persona(
) -> PersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
document_sets = list(
get_document_sets_by_ids(
document_set_ids=create_persona_request.document_set_ids,
db_session=db_session,
)
)
prompts = list(
get_prompts_by_ids(
prompt_ids=create_persona_request.prompt_ids,
db_session=db_session,
)
)
try:
persona = upsert_persona(
persona_id=persona_id,
@@ -85,9 +72,9 @@ def create_update_persona(
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
recency_bias=create_persona_request.recency_bias,
prompts=prompts,
prompt_ids=create_persona_request.prompt_ids,
tool_ids=create_persona_request.tool_ids,
document_sets=document_sets,
document_set_ids=create_persona_request.document_set_ids,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
@@ -330,13 +317,13 @@ def upsert_persona(
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DocumentSet] | None,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
is_public: bool,
db_session: Session,
prompt_ids: list[int] | None = None,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
@@ -356,6 +343,28 @@ def upsert_persona(
if not tools and tool_ids:
raise ValueError("Tools not found")
# Fetch and attach document_sets by IDs
document_sets = None
if document_set_ids is not None:
document_sets = (
db_session.query(DocumentSet)
.filter(DocumentSet.id.in_(document_set_ids))
.all()
)
if not document_sets and document_set_ids:
raise ValueError("document_sets not found")
# Fetch and attach prompts by IDs
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
@@ -383,10 +392,10 @@ def upsert_persona(
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts
persona.prompts = prompts or []
if tools is not None:
persona.tools = tools
persona.tools = tools or []
else:
persona = Persona(
@@ -453,6 +462,14 @@ def update_persona_visibility(
db_session.commit()
def validate_persona_tools(tools: list[Tool]) -> None:
for tool in tools:
if tool.name == "InternetSearchTool" and not BING_API_KEY:
raise ValueError(
"Bing API key not found, please contact your Danswer admin to get it added!"
)
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
# if user is None, assume that no-auth is turned on
if user is None:

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
@@ -15,6 +14,7 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
from danswer.search.enums import RecencyBiasSetting
@@ -42,12 +42,6 @@ def create_slack_bot_persona(
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(
get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
)
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
@@ -59,10 +53,10 @@ def create_slack_bot_persona(
description="",
num_chunks=num_chunks,
llm_relevance_filter=True,
llm_filter_extraction=True,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.AUTO,
prompts=[default_prompt],
document_sets=document_sets,
prompt_ids=[default_prompt.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -79,12 +73,25 @@ def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
db_session.add(slack_bot_config)
db_session.commit()
@@ -97,6 +104,8 @@ def update_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
slack_bot_config = db_session.scalar(
@@ -106,6 +115,16 @@ def update_slack_bot_config(
raise ValueError(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
# get the existing persona id before updating the object
existing_persona_id = slack_bot_config.persona_id
@@ -115,6 +134,10 @@ def update_slack_bot_config(
slack_bot_config.persona_id = persona_id
slack_bot_config.channel_config = channel_config
slack_bot_config.response_type = response_type
slack_bot_config.standard_answer_categories = list(
existing_standard_answer_categories
)
slack_bot_config.enable_auto_filters = enable_auto_filters
# if the persona has changed, then clean up the old persona
if persona_id != existing_persona_id and existing_persona_id:

View File

@@ -0,0 +1,239 @@
import string
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import StandardAnswer
from danswer.db.models import StandardAnswerCategory
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_category_validity(category_name: str) -> bool:
"""If a category name is too long, it should not be used (it will cause an error in Postgres
as the unique constraint can only apply to entries that are less than 2704 bytes).
Additionally, extremely long categories are not really usable / useful."""
if len(category_name) > 255:
logger.error(
f"Category with name '{category_name}' is too long, cannot be used"
)
return False
return True
def insert_standard_answer_category(
category_name: str, db_session: Session
) -> StandardAnswerCategory:
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category = StandardAnswerCategory(name=category_name)
db_session.add(standard_answer_category)
db_session.commit()
return standard_answer_category
def insert_standard_answer(
keyword: str,
answer: str,
category_ids: list[int],
db_session: Session,
) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer = StandardAnswer(
keyword=keyword,
answer=answer,
categories=existing_categories,
active=True,
)
db_session.add(standard_answer)
db_session.commit()
return standard_answer
def update_standard_answer(
standard_answer_id: int,
keyword: str,
answer: str,
category_ids: list[int],
db_session: Session,
) -> StandardAnswer:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer.keyword = keyword
standard_answer.answer = answer
standard_answer.categories = list(existing_categories)
db_session.commit()
return standard_answer
def remove_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> None:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
standard_answer.active = False
db_session.commit()
def update_standard_answer_category(
standard_answer_category_id: int,
category_name: str,
db_session: Session,
) -> StandardAnswerCategory:
standard_answer_category = db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
if standard_answer_category is None:
raise ValueError(
f"No standard answer category with id {standard_answer_category_id}"
)
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category.name = category_name
db_session.commit()
return standard_answer_category
def fetch_standard_answer_category(
standard_answer_category_id: int,
db_session: Session,
) -> StandardAnswerCategory | None:
return db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
def fetch_standard_answer_categories_by_names(
standard_answer_category_names: list[str],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.name.in_(standard_answer_category_names)
)
).all()
def fetch_standard_answer_categories_by_ids(
standard_answer_category_ids: list[int],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id.in_(standard_answer_category_ids)
)
).all()
def fetch_standard_answer_categories(
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(select(StandardAnswerCategory)).all()
def fetch_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> StandardAnswer | None:
return db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
def find_matching_standard_answers(
id_in: list[int],
query: str,
db_session: Session,
) -> list[StandardAnswer]:
stmt = (
select(StandardAnswer)
.where(StandardAnswer.active.is_(True))
.where(StandardAnswer.id.in_(id_in))
)
possible_standard_answers = db_session.scalars(stmt).all()
matching_standard_answers: list[StandardAnswer] = []
for standard_answer in possible_standard_answers:
# Remove punctuation and split the keyword into individual words
keyword_words = "".join(
char
for char in standard_answer.keyword.lower()
if char not in string.punctuation
).split()
# Remove punctuation and split the query into individual words
query_words = "".join(
char for char in query.lower() if char not in string.punctuation
).split()
# Check if all of the keyword words are in the query words
if all(word in query_words for word in keyword_words):
matching_standard_answers.append(standard_answer)
return matching_standard_answers
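
The matching rule above is a simple bag-of-words containment check: lowercase both strings, strip punctuation, and require every keyword token to appear among the query tokens. A standalone sketch of just that predicate:

import string

def keyword_matches(keyword: str, query: str) -> bool:
    """True if every word of `keyword` appears in `query`, ignoring case/punctuation."""
    def tokenize(text: str) -> list[str]:
        return "".join(
            char for char in text.lower() if char not in string.punctuation
        ).split()
    query_words = tokenize(query)
    return all(word in query_words for word in tokenize(keyword))

assert keyword_matches("PTO policy", "what's our PTO policy?")
assert not keyword_matches("PTO policy", "how do I request PTO?")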
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
return db_session.scalars(
select(StandardAnswer).where(StandardAnswer.active.is_(True))
).all()
def create_initial_default_standard_answer_category(db_session: Session) -> None:
default_category_id = 0
default_category_name = "General"
default_category = fetch_standard_answer_category(
standard_answer_category_id=default_category_id,
db_session=db_session,
)
if default_category is not None:
if default_category.name != default_category_name:
raise ValueError(
"DB is not in a valid initial state. "
"Default standard answer category does not have expected name."
)
return
standard_answer_category = StandardAnswerCategory(
id=default_category_id,
name=default_category_name,
)
db_session.add(standard_answer_category)
db_session.commit()

View File

@@ -1,5 +1,6 @@
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -107,18 +108,28 @@ def create_or_add_document_tag_list(
def get_tags_by_value_prefix_for_source_types(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
sources: list[DocumentSource] | None,
limit: int | None,
db_session: Session,
) -> list[Tag]:
query = select(Tag)
if tag_value_prefix:
query = query.where(Tag.tag_value.startswith(tag_value_prefix))
if tag_key_prefix or tag_value_prefix:
conditions = []
if tag_key_prefix:
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
if tag_value_prefix:
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
query = query.where(or_(*conditions))
if sources:
query = query.where(Tag.source.in_(sources))
if limit:
query = query.limit(limit)
result = db_session.execute(query)
tags = result.scalars().all()
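
A hedged usage sketch for the prefix search above; the prefixes and limit are placeholder values, and `db_session` is assumed open as in the rest of this module:

# Hypothetical usage of the case-insensitive prefix search.
tags = get_tags_by_value_prefix_for_source_types(
    tag_key_prefix="team",
    tag_value_prefix="eng",
    sources=None,  # or a list of DocumentSource values
    limit=50,
    db_session=db_session,
)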

View File

@@ -6,7 +6,7 @@ from typing import Any
from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@dataclass(frozen=True)
@@ -186,7 +186,7 @@ class IdRetrievalCapable(abc.ABC):
min_chunk_ind: int | None,
max_chunk_ind: int | None,
user_access_control_list: list[str] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Fetch chunk(s) based on document id
@@ -222,7 +222,7 @@ class KeywordCapable(abc.ABC):
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
information required for query time purposes. For example, some details of the document
@@ -262,7 +262,7 @@ class VectorCapable(abc.ABC):
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run vector/semantic search and return a list of inference chunks.
@@ -298,7 +298,7 @@ class HybridCapable(abc.ABC):
num_to_retrieve: int,
offset: int = 0,
hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run hybrid search and return a list of inference chunks.
@@ -348,7 +348,7 @@ class AdminCapable(abc.ABC):
filters: IndexFilters,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
"""
Run the special search for the admin document explorer page

View File

@@ -91,6 +91,9 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}
field doc_updated_at type int {
indexing: summary | attribute
}
@@ -150,43 +153,41 @@ schema DANSWER_CHUNK_NAME {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
# This must be a separate function for normalize_linear to work
function vector_score() {
function title_vector_score() {
expression {
# If no title, the full vector score comes from the content embedding
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
((1 - query(title_content_ratio)) * closeness(field, embeddings))
}
}
# This must be a separate function for normalize_linear to work
function keyword_score() {
expression {
(query(title_content_ratio) * bm25(title)) +
((1 - query(title_content_ratio)) * bm25(content))
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
}
}
first-phase {
expression: vector_score
expression: closeness(field, embeddings)
}
# Weighted average between Vector Search and BM-25
# Each is a weighted average between the Title and Content fields
# Finally, each doc is boosted by its user-feedback-based boost and its recency
# If any embedding or index field is missing, it just receives a score of 0
# Assumptions:
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution,
# so they are not normalized before combining.
# - Documents without a title get a title score of 0, which is fine since documents
# without any title match should be penalized.
global-phase {
expression {
(
# Weighted Vector Similarity Score
(query(alpha) * normalize_linear(vector_score)) +
(
query(alpha) * (
(query(title_content_ratio) * normalize_linear(title_vector_score))
+
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
)
)
+
# Weighted Keyword Similarity Score
((1 - query(alpha)) * normalize_linear(keyword_score))
(
(1 - query(alpha)) * (
(query(title_content_ratio) * normalize_linear(bm25(title)))
+
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
)
)
)
# Boost based on user feedback
* document_boost
@@ -201,8 +202,6 @@ schema DANSWER_CHUNK_NAME {
bm25(content)
closeness(field, title_embedding)
closeness(field, embeddings)
keyword_score
vector_score
document_boost
recency_bias
closest(embeddings)
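
Read as arithmetic, the global-phase expression is a nested convex blend: query(alpha) weights the vector side against BM-25, and query(title_content_ratio) weights title against content within each side, with the result then multiplied by the feedback boost and recency terms. A plain-Python sketch of the same blend, assuming all four component scores are already normalized (Vespa's normalize_linear is elided) and using purely illustrative default weights:

def hybrid_score(
    title_vector: float,
    content_vector: float,
    title_bm25: float,
    content_bm25: float,
    document_boost: float,
    recency_bias: float,
    alpha: float = 0.5,                # illustrative weight, not the deployed value
    title_content_ratio: float = 0.2,  # illustrative weight, not the deployed value
) -> float:
    # Weighted vector similarity (title vs. content)
    vector = title_content_ratio * title_vector + (1 - title_content_ratio) * content_vector
    # Weighted keyword similarity (title vs. content)
    keyword = title_content_ratio * title_bm25 + (1 - title_content_ratio) * content_bm25
    # Blend the two, then apply the feedback boost and recency terms
    return (alpha * vector + (1 - alpha) * keyword) * document_boost * recency_bias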

View File

@@ -41,6 +41,7 @@ from danswer.configs.constants import HIDDEN
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import METADATA
from danswer.configs.constants import METADATA_LIST
from danswer.configs.constants import METADATA_SUFFIX
from danswer.configs.constants import PRIMARY_OWNERS
from danswer.configs.constants import RECENCY_BIAS
from danswer.configs.constants import SECONDARY_OWNERS
@@ -51,7 +52,6 @@ from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.constants import TITLE
from danswer.configs.constants import TITLE_EMBEDDING
from danswer.configs.constants import TITLE_SEPARATOR
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -64,7 +64,7 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
from danswer.document_index.vespa.utils import replace_invalid_doc_id_characters
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
@@ -119,6 +119,7 @@ def _does_document_exist(
chunk. This checks for whether the chunk exists already in the index"""
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
doc_fetch_response = http_client.get(doc_url)
if doc_fetch_response.status_code == 404:
return False
@@ -346,8 +347,10 @@ def _index_vespa_chunk(
TITLE: remove_invalid_unicode_chars(title) if title else None,
SKIP_TITLE_EMBEDDING: not title,
CONTENT: remove_invalid_unicode_chars(chunk.content),
# This duplication of `content` is needed for keyword highlighting :(
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
# which contains the title prefix and metadata suffix
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -355,6 +358,7 @@ def _index_vespa_chunk(
METADATA: json.dumps(document.metadata),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,
@@ -559,7 +563,9 @@ def _process_dynamic_summary(
return processed_summary
def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
def _vespa_hit_to_inference_chunk(
hit: dict[str, Any], null_score: bool = False
) -> InferenceChunkUncleaned:
fields = cast(dict[str, Any], hit["fields"])
# parse fields that are stored as strings, but are really json / datetime
@@ -582,19 +588,6 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
)
# Remove the title from the first chunk as every chunk already included
# its semantic identifier for LLM
content = fields[CONTENT]
if fields[CHUNK_ID] == 0:
parts = content.split(TITLE_SEPARATOR, maxsplit=1)
content = parts[1] if len(parts) > 1 and "\n" not in parts[0] else content
# User ran into this, not sure why this could happen, error checking here
blurb = fields.get(BLURB)
if not blurb:
logger.error(f"Chunk with id {fields.get(semantic_identifier)} ")
blurb = ""
source_links = fields.get(SOURCE_LINKS, {})
source_links_dict_unprocessed = (
json.loads(source_links) if isinstance(source_links, str) else source_links
@@ -604,29 +597,33 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
}
return InferenceChunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=blurb,
content=content,
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
source_links=source_links_dict,
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
source_type=fields[SOURCE_TYPE],
title=fields.get(TITLE),
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
boost=fields.get(BOOST, 1),
recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
score=hit.get("relevance", 0),
score=None if null_score else hit.get("relevance", 0),
hidden=fields.get(HIDDEN, False),
primary_owners=fields.get(PRIMARY_OWNERS),
secondary_owners=fields.get(SECONDARY_OWNERS),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
match_highlights=match_highlights,
updated_at=updated_at,
)
@retry(tries=3, delay=1, backoff=2)
def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[InferenceChunk]:
def _query_vespa(
query_params: Mapping[str, str | int | float]
) -> list[InferenceChunkUncleaned]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError("No/empty query received")
@@ -681,16 +678,6 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc
return inference_chunks
@retry(tries=3, delay=1, backoff=2)
def _inference_chunk_by_vespa_id(vespa_id: str, index_name: str) -> InferenceChunk:
res = requests.get(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_id}"
)
res.raise_for_status()
return _vespa_hit_to_inference_chunk(res.json())
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -735,6 +722,7 @@ class VespaIndex(DocumentIndex):
f"{SOURCE_TYPE}, "
f"{SOURCE_LINKS}, "
f"{SEMANTIC_IDENTIFIER}, "
f"{TITLE}, "
f"{SECTION_CONTINUATION}, "
f"{BOOST}, "
f"{HIDDEN}, "
@@ -742,6 +730,7 @@ class VespaIndex(DocumentIndex):
f"{PRIMARY_OWNERS}, "
f"{SECONDARY_OWNERS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)
@@ -977,7 +966,7 @@ class VespaIndex(DocumentIndex):
min_chunk_ind: int | None,
max_chunk_ind: int | None,
user_access_control_list: list[str] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
document_id = replace_invalid_doc_id_characters(document_id)
vespa_chunks = _get_vespa_chunks_by_document_id(
@@ -992,7 +981,8 @@ class VespaIndex(DocumentIndex):
return []
inference_chunks = [
_vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
_vespa_hit_to_inference_chunk(chunk, null_score=True)
for chunk in vespa_chunks
]
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
return inference_chunks
@@ -1005,7 +995,7 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
@@ -1042,7 +1032,7 @@ class VespaIndex(DocumentIndex):
offset: int = 0,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
@@ -1086,7 +1076,7 @@ class VespaIndex(DocumentIndex):
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = _build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)
@@ -1130,7 +1120,7 @@ class VespaIndex(DocumentIndex):
filters: IndexFilters,
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunk]:
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = _build_vespa_filters(filters, include_hidden=True)
yql = (
VespaIndex.yql_base.format(index_name=self.index_name)

View File
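The changes above switch the retrieval paths to return InferenceChunkUncleaned: the chunk's content still carries the indexed title prefix and metadata suffix, which a later cleanup step removes. A minimal sketch of that idea, assuming the stripping is simple prefix/suffix removal (the helper name and logic are illustrative, not the repo's actual implementation):

def strip_title_and_metadata(
    content: str, title_prefix: str, metadata_suffix: str
) -> str:
    # Remove the title prepended at indexing time, if present
    if title_prefix and content.startswith(title_prefix):
        content = content[len(title_prefix) :]
    # Remove the metadata block appended at indexing time, if present
    if metadata_suffix and content.endswith(metadata_suffix):
        content = content[: -len(metadata_suffix)]
    return content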

@@ -3,12 +3,16 @@ from collections.abc import Callable
from typing import TYPE_CHECKING
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import CHUNK_OVERLAP
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.constants import TITLE_SEPARATOR
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
@@ -19,6 +23,14 @@ if TYPE_CHECKING:
from transformers import AutoTokenizer # type:ignore
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
# actually help quality at all
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers, but the general concept is that we don't want the title/metadata
# to overwhelm the actual contents of the chunk
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
logger = setup_logger()
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
@@ -44,6 +56,8 @@ def chunk_large_section(
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
title_prefix: str = "",
metadata_suffix: str = "",
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
@@ -60,30 +74,69 @@ def chunk_large_section(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
blurb=blurb,
content=chunk_str,
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
content_summary=chunk_str,
source_links={0: section_link_text},
section_continuation=(chunk_ind != 0),
metadata_suffix=metadata_suffix,
)
for chunk_ind, chunk_str in enumerate(split_texts)
]
return chunks
def _get_metadata_suffix_for_document_index(
metadata: dict[str, str | list[str]]
) -> str:
if not metadata:
return ""
metadata_str = "Metadata:\n"
for key, value in metadata.items():
if key in get_metadata_keys_to_ignore():
continue
value_str = ", ".join(value) if isinstance(value, list) else value
metadata_str += f"\t{key} - {value_str}\n"
return metadata_str.strip()
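# Illustrative call with assumed metadata (format per the function above):
#   _get_metadata_suffix_for_document_index({"author": "jane", "tags": ["a", "b"]})
#   -> "Metadata:\n\tauthor - jane\n\ttags - a, b"
# Keys returned by get_metadata_keys_to_ignore() would be skipped.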
def chunk_document(
document: Document,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
) -> list[DocAwareChunk]:
title = document.get_title_for_document_index()
title_prefix = title.replace("\n", " ") + TITLE_SEPARATOR if title else ""
tokenizer = get_default_tokenizer()
title = document.get_title_for_document_index()
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
title_tokens = len(tokenizer.tokenize(title_prefix))
metadata_suffix = ""
metadata_tokens = 0
if include_metadata:
metadata = _get_metadata_suffix_for_document_index(document.metadata)
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
metadata_suffix = ""
metadata_tokens = 0
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
content_token_limit = chunk_tok_size
title_prefix = ""
metadata_suffix = ""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for ind, section in enumerate(document.sections):
section_text = title_prefix + section.text if ind == 0 else section.text
for section in document.sections:
section_text = section.text
section_link_text = section.link or ""
section_tok_length = len(tokenizer.tokenize(section_text))
@@ -92,16 +145,18 @@ def chunk_document(
# Large sections are considered self-contained/unique, therefore they start a new chunk and are
# not concatenated with other sections at the end
if section_tok_length > chunk_tok_size:
if section_tok_length > content_token_limit:
if chunk_text:
chunks.append(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=chunk_text,
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
link_offsets = {}
@@ -113,9 +168,11 @@ def chunk_document(
document=document,
start_chunk_id=len(chunks),
tokenizer=tokenizer,
chunk_size=chunk_tok_size,
chunk_size=content_token_limit,
chunk_overlap=subsection_overlap,
blurb_size=blurb_size,
title_prefix=title_prefix,
metadata_suffix=metadata_suffix,
)
chunks.extend(large_section_chunks)
continue
@@ -125,7 +182,7 @@ def chunk_document(
current_tok_length
+ len(tokenizer.tokenize(SECTION_SEPARATOR))
+ section_tok_length
<= chunk_tok_size
<= content_token_limit
):
chunk_text += (
SECTION_SEPARATOR + section_text if chunk_text else section_text
@@ -137,9 +194,11 @@ def chunk_document(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=chunk_text,
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
link_offsets = {0: section_link_text}
@@ -153,9 +212,11 @@ def chunk_document(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_size),
content=chunk_text,
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
content_summary=chunk_text,
source_links=link_offsets,
section_continuation=False,
metadata_suffix=metadata_suffix,
)
)
return chunks
@@ -164,6 +225,9 @@ def chunk_document(
def split_chunk_text_into_mini_chunks(
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
"""The minichunks won't all have the title prefix or metadata suffix
It could be a significant percentage of every minichunk so better to not include it
"""
from llama_index.text_splitter import SentenceSplitter
token_count_func = get_default_tokenizer().tokenize

View File

@@ -14,12 +14,12 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
logger = setup_logger()
@@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None = None,
provider_type: str | None = None,
):
super().__init__(model_name, normalize, query_prefix, passage_prefix)
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
@@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
query_prefix=query_prefix,
passage_prefix=passage_prefix,
normalize=normalize,
api_key=api_key,
provider_type=provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@@ -81,7 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
mini_chunk_texts = (
split_chunk_text_into_mini_chunks(chunk.content)
split_chunk_text_into_mini_chunks(chunk.content_summary)
if enable_mini_chunk
else []
)

View File
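A short sketch of why mini-chunks are now split from content_summary rather than content (the string values are assumed for illustration):

# content is what gets embedded for the full chunk and includes the title
# prefix and metadata suffix; content_summary is just the raw section text.
chunk_content = "Setup Guide\nInstall Docker before running the stack.\nMetadata:\n\tauthor - jane"
chunk_content_summary = "Install Docker before running the stack."

# Mini-chunks are split from the raw text; otherwise the title/metadata could
# dominate the embedding of every small piece.
mini_chunk_texts = split_chunk_text_into_mini_chunks(chunk_content_summary)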

@@ -36,6 +36,16 @@ class DocAwareChunk(BaseChunk):
# During inference we only have access to the document id and do not reconstruct the Document
source_document: Document
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
# it's easier to just store an un-prefixed/suffixed string for the highlighting.
# During chunking, this same un-prefixed/suffixed string is also used for the mini-chunks
content_summary: str
# During indexing we also (optionally) build a metadata string from the metadata dict
# This is also indexed so that we can strip it out after indexing; this way it supports
# multiple iterations of metadata representation for backwards compatibility
metadata_suffix: str
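# Illustration with assumed values: if metadata_suffix is
# "Metadata:\n\tauthor - jane", the stored fields relate as
#   content         -> "<title prefix><section text><metadata_suffix>"
#   content_summary -> "<section text>" (used for highlighting and mini-chunks)
#   metadata_suffix -> "Metadata:\n\tauthor - jane" (strippable at inference)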
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""
return (
@@ -87,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
cloud_provider_id: int | None = None
cloud_provider_name: str | None = None
@classmethod
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
def from_model(
cls,
embedding_model: "EmbeddingModel",
) -> "EmbeddingModelDetail":
return cls(
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
cloud_provider_id=embedding_model.cloud_provider_id,
)

View File

@@ -31,6 +31,8 @@ from danswer.llm.answering.stream_processing.citation_processing import (
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
@@ -43,6 +45,7 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
@@ -58,17 +61,22 @@ from danswer.tools.tool_runner import (
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_answer_stream_processor(
context_docs: list[LlmDoc],
search_order_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, search_order_docs=search_order_docs
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
@@ -81,6 +89,9 @@ def _get_answer_stream_processor(
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
logger = setup_logger()
class Answer:
def __init__(
self,
@@ -104,6 +115,7 @@ class Answer:
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> None:
if single_message_history and message_history:
raise ValueError(
@@ -132,11 +144,12 @@ class Answer:
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: list[
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
] | None = None
self._processed_stream: (
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@@ -228,7 +241,7 @@ class Answer:
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool = [
tool for tool in self.tools if tool.name() == tool_call_request["name"]
tool for tool in self.tools if tool.name == tool_call_request["name"]
][0]
tool_args = (
self.force_use_tool.args
@@ -247,15 +260,14 @@ class Answer:
),
)
if tool.name() == SearchTool.NAME:
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name() == ImageGenerationTool.NAME:
elif tool.name == ImageGenerationTool._NAME:
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
)
)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
@@ -281,7 +293,7 @@ class Answer:
[
tool
for tool in self.tools
if tool.name() == self.force_use_tool.tool_name
if tool.name == self.force_use_tool.tool_name
]
),
None,
@@ -301,21 +313,39 @@ class Answer:
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name()}' did not return args")
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
all_tool_args = check_which_tools_should_run_for_non_tool_calling_llm(
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
for ind, args in enumerate(all_tool_args):
if args is not None:
chosen_tool_and_args = (self.tools[ind], args)
# for now, just pick the first tool selected
break
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.info(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
@@ -336,7 +366,7 @@ class Answer:
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name() == SearchTool.NAME:
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS:
@@ -344,12 +374,14 @@ class Answer:
yield response
if final_context_documents is None:
raise RuntimeError("SearchTool did not return final context documents")
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name() == ImageGenerationTool.NAME:
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
@@ -371,13 +403,14 @@ class Answer:
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name(),
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
yield tool_runner.tool_final_result()
yield final
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
@@ -417,6 +450,10 @@ class Answer:
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
# We don't need to run section merging in this flow; this variable is only used
# below to specify the ordering of the documents for the purpose of matching
# citations to the right search documents. The deduplication logic is more lightweight
# there and we don't need to do it twice
search_results = [
llm_doc_from_inference_section(section)
for section in cast(
@@ -436,20 +473,23 @@ class Answer:
# assumes all tool responses will come first, then the final answer
break
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
search_order_docs=search_results or final_context_docs or [],
answer_style_configs=self.answer_style_config,
)
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
def _stream() -> Iterator[str]:
if message:
yield cast(str, message)
yield from cast(Iterator[str], stream)
def _stream() -> Iterator[str]:
if message:
yield cast(str, message)
yield from cast(Iterator[str], stream)
yield from process_answer_stream_fn(_stream())
yield from process_answer_stream_fn(_stream())
processed_stream = []
for processed_packet in _process_stream(output_generator):

View File

@@ -1,230 +0,0 @@
import json
from copy import deepcopy
from typing import TypeVar
from danswer.chat.models import (
LlmDoc,
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.tools.search.search_utils import llm_doc_to_dict
from danswer.utils.logger import setup_logger
logger = setup_logger()
T = TypeVar("T", bound=LlmDoc | InferenceChunk)
_METADATA_TOKEN_ESTIMATE = 75
class PruningError(Exception):
pass
def _compute_limit(
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
max_chunks: int | None,
max_window_percentage: float | None,
max_tokens: int | None,
tool_token_count: int,
) -> int:
llm_max_document_tokens = compute_max_document_tokens(
prompt_config=prompt_config,
llm_config=llm_config,
tool_token_count=tool_token_count,
actual_user_input=question,
)
window_percentage_based_limit = (
max_window_percentage * llm_max_document_tokens
if max_window_percentage
else None
)
chunk_count_based_limit = (
max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
)
limit_options = [
lim
for lim in [
window_percentage_based_limit,
chunk_count_based_limit,
max_tokens,
llm_max_document_tokens,
]
if lim
]
return int(min(limit_options))
def reorder_docs(
docs: list[T],
doc_relevance_list: list[bool] | None,
) -> list[T]:
if doc_relevance_list is None:
return docs
reordered_docs: list[T] = []
if doc_relevance_list is not None:
for selection_target in [True, False]:
for doc, is_relevant in zip(docs, doc_relevance_list):
if is_relevant == selection_target:
reordered_docs.append(doc)
return reordered_docs
def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]
def _apply_pruning(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
token_limit: int,
is_manually_selected_docs: bool,
use_sections: bool,
using_tool_message: bool,
) -> list[LlmDoc]:
llm_tokenizer = get_default_llm_tokenizer()
docs = deepcopy(docs) # don't modify in place
# re-order docs with all the "relevant" docs at the front
docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
# remove docs that are explicitly marked as not for QA
docs = _remove_docs_to_ignore(docs=docs)
tokens_per_doc: list[int] = []
final_doc_ind = None
total_tokens = 0
for ind, llm_doc in enumerate(docs):
doc_str = (
json.dumps(llm_doc_to_dict(llm_doc, ind))
if using_tool_message
else build_doc_context_str(
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
content=llm_doc.content,
metadata_dict=llm_doc.metadata,
updated_at=llm_doc.updated_at,
ind=ind,
)
)
doc_tokens = len(llm_tokenizer.encode(doc_str))
# if chunks, truncate chunks that are way too long
# this can happen if the embedding model tokenizer is different
# than the LLM tokenizer
if (
not is_manually_selected_docs
and not use_sections
and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
):
logger.warning(
"Found more tokens in chunk than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
llm_doc.content = tokenizer_trim_content(
content=llm_doc.content,
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
tokenizer=llm_tokenizer,
)
doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
tokens_per_doc.append(doc_tokens)
total_tokens += doc_tokens
if total_tokens > token_limit:
final_doc_ind = ind
break
if final_doc_ind is not None:
if is_manually_selected_docs or use_sections:
# for document selection, only allow the final document to get truncated
# if more than that, then the user message is too long
if final_doc_ind != len(docs) - 1:
if use_sections:
# Truncate the rest of the list since we're over the token limit
# for the last one, trim it. In this case, the Sections can be rather long
# so better to trim the back than throw away the whole thing.
docs = docs[: final_doc_ind + 1]
else:
raise PruningError(
"LLM context window exceeded. Please de-select some documents or shorten your query."
)
amount_to_truncate = total_tokens - token_limit
# NOTE: need to recalculate the length here, since the previous calculation included
# overhead from JSON-fying the doc / the metadata
final_doc_content_length = len(
llm_tokenizer.encode(docs[final_doc_ind].content)
) - (amount_to_truncate)
# this could occur if we only have space for the title / metadata
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
"length is less than 0. Removing this doc from the final prompt."
)
docs.pop()
else:
docs[final_doc_ind].content = tokenizer_trim_content(
content=docs[final_doc_ind].content,
desired_length=final_doc_content_length,
tokenizer=llm_tokenizer,
)
else:
# For regular search, don't truncate the final document unless it's the only one
# If it's not the only one, we can throw it away, if it's the only one, we have to truncate
if final_doc_ind != 0:
docs = docs[:final_doc_ind]
else:
docs[0].content = tokenizer_trim_content(
content=docs[0].content,
desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
tokenizer=llm_tokenizer,
)
docs = [docs[0]]
return docs
def prune_documents(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[LlmDoc]:
if doc_relevance_list is not None:
assert len(docs) == len(doc_relevance_list)
doc_token_limit = _compute_limit(
prompt_config=prompt_config,
llm_config=llm_config,
question=question,
max_chunks=document_pruning_config.max_chunks,
max_window_percentage=document_pruning_config.max_window_percentage,
max_tokens=document_pruning_config.max_tokens,
tool_token_count=document_pruning_config.tool_num_tokens,
)
return _apply_pruning(
docs=docs,
doc_relevance_list=doc_relevance_list,
token_limit=doc_token_limit,
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
use_sections=document_pruning_config.use_sections,
using_tool_message=document_pruning_config.using_tool_message,
)

View File

@@ -70,9 +70,11 @@ class DocumentPruningConfig(BaseModel):
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context chunks for each match, then different pruning
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
use_sections: bool = False
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize

View File

@@ -0,0 +1,360 @@
import json
from collections import defaultdict
from copy import deepcopy
from typing import TypeVar
from pydantic import BaseModel
from danswer.chat.models import (
LlmDoc,
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger
logger = setup_logger()
T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection)
_METADATA_TOKEN_ESTIMATE = 75
class PruningError(Exception):
pass
class ChunkRange(BaseModel):
chunks: list[InferenceChunk]
start: int
end: int
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
"""
This acts on a single document to merge the overlapping ranges of chunks
Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
document index, this does not merge the actual contents so it should not be used to actually
merge chunks post retrieval.
"""
sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
combined_ranges: list[ChunkRange] = []
for new_chunk_range in sorted_ranges:
if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
combined_ranges.append(new_chunk_range)
else:
current_range = combined_ranges[-1]
current_range.end = max(current_range.end, new_chunk_range.end)
current_range.chunks.extend(new_chunk_range.chunks)
return combined_ranges
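# Usage sketch with assumed chunk ids: [0-2] and [2-4] overlap, so they merge
# into a single [0-4] range; [7-8] is disjoint and survives on its own.
example_ranges = [
    ChunkRange(chunks=[], start=0, end=2),
    ChunkRange(chunks=[], start=2, end=4),
    ChunkRange(chunks=[], start=7, end=8),
]
assert [(r.start, r.end) for r in merge_chunk_intervals(example_ranges)] == [
    (0, 4),
    (7, 8),
]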
def _compute_limit(
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
max_chunks: int | None,
max_window_percentage: float | None,
max_tokens: int | None,
tool_token_count: int,
) -> int:
llm_max_document_tokens = compute_max_document_tokens(
prompt_config=prompt_config,
llm_config=llm_config,
tool_token_count=tool_token_count,
actual_user_input=question,
)
window_percentage_based_limit = (
max_window_percentage * llm_max_document_tokens
if max_window_percentage
else None
)
chunk_count_based_limit = (
max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
)
limit_options = [
lim
for lim in [
window_percentage_based_limit,
chunk_count_based_limit,
max_tokens,
llm_max_document_tokens,
]
if lim
]
return int(min(limit_options))
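# Worked example with assumed inputs: llm_max_document_tokens=4096,
# max_window_percentage=0.5 -> 2048.0, max_chunks=10 with an assumed
# DOC_EMBEDDING_CONTEXT_SIZE of 512 -> 5120, max_tokens=None (filtered out).
# int(min(2048.0, 5120, 4096)) -> 2048.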
def reorder_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
) -> list[InferenceSection]:
if section_relevance_list is None:
return sections
reordered_sections: list[InferenceSection] = []
if section_relevance_list is not None:
for selection_target in [True, False]:
for section, is_relevant in zip(sections, section_relevance_list):
if is_relevant == selection_target:
reordered_sections.append(section)
return reordered_sections
def _remove_sections_to_ignore(
sections: list[InferenceSection],
) -> list[InferenceSection]:
return [
section
for section in sections
if not section.center_chunk.metadata.get(IGNORE_FOR_QA)
]
def _apply_pruning(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
token_limit: int,
is_manually_selected_docs: bool,
use_sections: bool,
using_tool_message: bool,
) -> list[InferenceSection]:
llm_tokenizer = get_default_llm_tokenizer()
sections = deepcopy(sections) # don't modify in place
# re-order docs with all the "relevant" docs at the front
sections = reorder_sections(
sections=sections, section_relevance_list=section_relevance_list
)
# remove docs that are explicitly marked as not for QA
sections = _remove_sections_to_ignore(sections=sections)
final_section_ind = None
total_tokens = 0
for ind, section in enumerate(sections):
section_str = (
# If using tool message, it will be a bit of an overestimate as the extra json text around the section
# will be counted towards the token count. However, once the Sections are merged, the extra json parts
# that overlap will not be counted multiple times as they are in the pruning step.
json.dumps(section_to_dict(section, ind))
if using_tool_message
else build_doc_context_str(
semantic_identifier=section.center_chunk.semantic_identifier,
source_type=section.center_chunk.source_type,
content=section.combined_content,
metadata_dict=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
ind=ind,
)
)
section_tokens = len(llm_tokenizer.encode(section_str))
# If not using sections (specifically, when each Section maps exactly to its one center chunk),
# truncate chunks that are way too long. This can happen if the embedding model tokenizer is
# different from the LLM tokenizer
if (
not is_manually_selected_docs
and not use_sections
and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
):
logger.warning(
"Found more tokens in Section than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
section.combined_content = tokenizer_trim_content(
content=section.combined_content,
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
tokenizer=llm_tokenizer,
)
section_tokens = DOC_EMBEDDING_CONTEXT_SIZE
total_tokens += section_tokens
if total_tokens > token_limit:
final_section_ind = ind
break
if final_section_ind is not None:
if is_manually_selected_docs or use_sections:
if final_section_ind != len(sections) - 1:
# If using Sections, then the final section could be more than we need; in this case we are willing to
# truncate the final section to fit the specified context window
sections = sections[: final_section_ind + 1]
if is_manually_selected_docs:
# For document selection flow, only allow the final document/section to get truncated.
# If more than that needs to be thrown away, then some documents are completely thrown away, in
# which case this should be reported to the user as an error
raise PruningError(
"LLM context window exceeded. Please de-select some documents or shorten your query."
)
amount_to_truncate = total_tokens - token_limit
# NOTE: need to recalculate the length here, since the previous calculation included
# overhead from JSON-fying the doc / the metadata
final_doc_content_length = len(
llm_tokenizer.encode(sections[final_section_ind].combined_content)
) - (amount_to_truncate)
# this could occur if we only have space for the title / metadata
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final section ({sections[final_section_ind].center_chunk.semantic_identifier}) content "
"length is less than 0. Removing this section from the final prompt."
)
sections.pop()
else:
sections[final_section_ind].combined_content = tokenizer_trim_content(
content=sections[final_section_ind].combined_content,
desired_length=final_doc_content_length,
tokenizer=llm_tokenizer,
)
else:
# For search on chunk level (Section is just a chunk), don't truncate the final Chunk/Section unless it's the only one
# If it's not the only one, we can throw it away, if it's the only one, we have to truncate
if final_section_ind != 0:
sections = sections[:final_section_ind]
else:
sections[0].combined_content = tokenizer_trim_content(
content=sections[0].combined_content,
desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
tokenizer=llm_tokenizer,
)
sections = [sections[0]]
return sections
def prune_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
# Assumes the sections are score ordered with highest first
if section_relevance_list is not None:
assert len(sections) == len(section_relevance_list)
token_limit = _compute_limit(
prompt_config=prompt_config,
llm_config=llm_config,
question=question,
max_chunks=document_pruning_config.max_chunks,
max_window_percentage=document_pruning_config.max_window_percentage,
max_tokens=document_pruning_config.max_tokens,
tool_token_count=document_pruning_config.tool_num_tokens,
)
return _apply_pruning(
sections=sections,
section_relevance_list=section_relevance_list,
token_limit=token_limit,
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
use_sections=document_pruning_config.use_sections, # Now default True
using_tool_message=document_pruning_config.using_tool_message,
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
center_chunk = max(
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
)
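# Sketch with assumed (chunk_id, score) pairs (3, 0.2), (4, 0.9), (9, 0.5):
# ids 3 and 4 are adjacent so their contents join with a single newline, the
# jump from 4 to 9 inserts "\n\n...\n\n", and the center chunk is the
# highest-scoring one (id 4), regardless of its position.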
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
or existing_chunk.score is None
or chunk.score is not None
and chunk.score > existing_chunk.score
):
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
new_sections.sort(
key=lambda x: (
x.center_chunk.score if x.center_chunk.score is not None else 0,
-1 * doc_order[x.center_chunk.document_id],
),
reverse=True,
)
return new_sections
def prune_and_merge_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
# Assumes the sections are score ordered with highest first
remaining_sections = prune_sections(
sections=sections,
section_relevance_list=section_relevance_list,
prompt_config=prompt_config,
llm_config=llm_config,
question=question,
document_pruning_config=document_pruning_config,
)
merged_sections = _merge_sections(sections=remaining_sections)
return merged_sections

View File

@@ -7,7 +7,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -23,104 +23,167 @@ def in_code_block(llm_text: str) -> bool:
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
"""
Key aspects:
1. Stream Processing:
- Processes tokens one by one, allowing for real-time handling of large texts.
2. Citation Detection:
- Uses regex to find citations in the format [number].
- Example: [1], [2], etc.
3. Citation Mapping:
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
4. Citation Formatting:
- Replaces citations with properly formatted versions.
- Adds links if available: [[1]](https://example.com)
- Handles cases where links are not available: [[1]]()
5. Duplicate Handling:
- Skips consecutive citations of the same document to avoid redundancy.
6. Output Generation:
- Yields DanswerAnswerPiece objects for regular text.
- Yields CitationInfo objects for each unique citation encountered.
7. Context Awareness:
- Uses context_docs to access document information for citations.
This function effectively processes a stream of text, identifies and reformats citations,
and provides both the processed text and citation information as output.
"""
order_mapping = doc_id_to_rank_map.order_mapping
llm_out = ""
max_citation_num = len(context_docs)
citation_order = []
curr_segment = ""
prepend_bracket = False
cited_inds = set()
hold = ""
raw_out = ""
current_citations: list[int] = []
past_cite_count = 0
for raw_token in tokens:
raw_out += raw_token
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
token = next_hold
hold = ""
else:
token = raw_token
# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
curr_segment += token
llm_out += token
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
# `past_cite_count`: number of characters since the last citation
# 5 to ensure a citation hasn't occurred
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
current_citations = []
if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
] # remove 1 index offset
if citations_found and not in_code_block(llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[numerical_value - 1]
real_citation_num = order_mapping[context_llm_doc.document_id]
# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)
if real_citation_num not in citation_order:
citation_order.append(real_citation_num)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
target_citation_num = citation_order.index(real_citation_num) + 1
# Skip consecutive citations of the same work
if target_citation_num in current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
curr_segment = (
curr_segment[: length_to_add + start]
+ curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ curr_segment[end + length_to_add :]
)
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
past_cite_count = len(llm_out)
current_citations.append(target_citation_num)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if link:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length
else:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
curr_segment = curr_segment[last_citation_end:]
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
yield DanswerAnswerPiece(answer_piece=curr_segment)
def build_citation_processor(
context_docs: list[LlmDoc], search_order_docs: list[LlmDoc]
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
) -> StreamProcessor:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=map_document_id_order(search_order_docs),
doc_id_to_rank_map=doc_id_to_rank_map,
)
return stream_processor
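# Hedged end-to-end illustration (doc ids, ranks, and link are assumed):
#   order_mapping maps the second context doc to rank 1 and it has a link.
#   Model output:     "See [2]."
#   Streamed rewrite: "See [[1]](https://example.com/doc)."
#   Side channel:     CitationInfo(citation_num=1, document_id="doc-2")
# An immediate repeat of [2] would be dropped as a consecutive duplicate.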

View File

@@ -1,12 +1,18 @@
from collections.abc import Sequence
from pydantic import BaseModel
from danswer.chat.models import LlmDoc
from danswer.search.models import InferenceChunk
class DocumentIdOrderMapping(BaseModel):
order_mapping: dict[str, int]
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
@@ -14,4 +20,4 @@ def map_document_id_order(
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
return DocumentIdOrderMapping(order_mapping=order_mapping)
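A runnable sketch of the new return type (the _Doc stand-in is illustrative; type annotations aside, only document_id is read):

from dataclasses import dataclass

@dataclass
class _Doc:
    document_id: str

# With the default one_indexed=True, ranks start at 1 in search order
mapping = map_document_id_order([_Doc("a"), _Doc("b")])
assert mapping.order_mapping == {"a": 1, "b": 2}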

View File

@@ -23,6 +23,7 @@ from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
@@ -42,6 +43,8 @@ logger = setup_logger()
litellm.drop_params = True
litellm.telemetry = False
litellm.set_verbose = LOG_ALL_MODEL_INTERACTIONS
def _base_msg_to_role(msg: BaseMessage) -> str:
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
@@ -229,32 +232,6 @@ class DefaultMultiLLM(LLM):
self._model_kwargs = model_kwargs
@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
if isinstance(msg, AIMessageChunk):
if msg.content:
log_msg = msg.content
elif msg.tool_call_chunks:
log_msg = "Tool Calls: " + str(
[
{
key: value
for key, value in tool_call.items()
if key != "index"
}
for tool_call in msg.tool_call_chunks
]
)
else:
log_msg = ""
logger.debug(f"Message {ind}:\n{log_msg}")
else:
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
def log_model_configs(self) -> None:
logger.info(f"Config: {self.config}")
@@ -304,17 +281,18 @@ class DefaultMultiLLM(LLM):
model_name=self._model_version,
temperature=self._temperature,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
)
def invoke(
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> BaseMessage:
if LOG_ALL_MODEL_INTERACTIONS:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
self._log_prompt(prompt)
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
@@ -323,15 +301,14 @@ class DefaultMultiLLM(LLM):
response.choices[0].message
)
def stream(
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[BaseMessage]:
if LOG_ALL_MODEL_INTERACTIONS:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
self._log_prompt(prompt)
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)
@@ -357,7 +334,7 @@ class DefaultMultiLLM(LLM):
"The AI model failed partway through generation, please try again."
)
if LOG_ALL_MODEL_INTERACTIONS and output:
if LOG_DANSWER_MODEL_INTERACTIONS and output:
content = output.content or ""
if isinstance(output, AIMessage):
if content:

View File

@@ -76,7 +76,7 @@ class CustomModelServer(LLM):
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")
def invoke(
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
@@ -84,7 +84,7 @@ class CustomModelServer(LLM):
) -> BaseMessage:
return self._execute(prompt)
def stream(
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,

View File

@@ -3,9 +3,12 @@ from collections.abc import Iterator
from typing import Literal
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.utils.logger import setup_logger
@@ -19,6 +22,34 @@ class LLMConfig(BaseModel):
model_name: str
temperature: float
api_key: str | None
api_base: str | None
api_version: str | None
def log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
if isinstance(msg, AIMessageChunk):
if msg.content:
log_msg = msg.content
elif msg.tool_call_chunks:
log_msg = "Tool Calls: " + str(
[
{
key: value
for key, value in tool_call.items()
if key != "index"
}
for tool_call in msg.tool_call_chunks
]
)
else:
log_msg = ""
logger.debug(f"Message {ind}:\n{log_msg}")
else:
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
class LLM(abc.ABC):
@@ -43,20 +74,48 @@ class LLM(abc.ABC):
def log_model_configs(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def _precall(self, prompt: LanguageModelInput) -> None:
if DISABLE_GENERATIVE_AI:
raise Exception("Generative AI is disabled")
if LOG_DANSWER_MODEL_INTERACTIONS:
log_prompt(prompt)
def invoke(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(prompt, tools, tool_choice)
@abc.abstractmethod
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> BaseMessage:
raise NotImplementedError
@abc.abstractmethod
def stream(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(prompt, tools, tool_choice)
@abc.abstractmethod
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
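A minimal subclass sketch of the new template-method contract (EchoLLM is hypothetical, and any abstract members outside this hunk are omitted): the public invoke()/stream() run _precall, i.e. the DISABLE_GENERATIVE_AI check plus optional prompt logging, then dispatch to the implementation hooks that concrete classes override.

class EchoLLM(LLM):
    def log_model_configs(self) -> None:
        logger.info("echo model, no external config")

    def _invoke_implementation(self, prompt, tools=None, tool_choice=None):
        # Echo the prompt back as a single message
        return AIMessageChunk(content=str(prompt))

    def _stream_implementation(self, prompt, tools=None, tool_choice=None):
        # Stream the same content as one piece
        yield AIMessageChunk(content=str(prompt))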

View File

@@ -26,6 +26,7 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"gpt-4",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",

View File

@@ -46,6 +46,7 @@ from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.persona import delete_old_default_personas
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
@@ -66,11 +67,14 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
from danswer.server.features.tool.api import router as tool_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.manage.administrative import router as admin_router
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
from danswer.server.manage.secondary_index import router as secondary_index_router
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.standard_answer import router as standard_answer_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.query_and_chat.chat_backend import router as chat_router
@@ -207,6 +211,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.info("Loading LLM providers from env variables")
load_llm_providers(db_session)
@@ -242,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
time.sleep(wait_time)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -273,6 +281,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, prompt_router)
@@ -285,6 +294,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)

View File

@@ -10,6 +10,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -21,6 +22,7 @@ from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
@@ -48,6 +50,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_EVALUATION_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -57,6 +60,7 @@ from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
AnswerObjectIterator = Iterator[
@@ -70,6 +74,7 @@ AnswerObjectIterator = Iterator[
| ChatMessageDetail
| CitationInfo
| ToolCallKickoff
| LLMRelevanceSummaryResponse
]
@@ -88,8 +93,9 @@ def stream_answer_objects(
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
"""Streams in order:
@@ -127,6 +133,7 @@ def stream_answer_objects(
user_query=query_msg.message,
history_str=history_str,
)
# Given back ahead of the documents for latency reasons
# In chat flow it's given back along with the documents
yield QueryRephrase(rephrased_query=rephrased_query)
@@ -168,6 +175,7 @@ def stream_answer_objects(
max_tokens=max_document_tokens,
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
)
search_tool = SearchTool(
db_session=db_session,
user=user,
@@ -177,7 +185,11 @@ def stream_answer_objects(
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
llm_doc_eval=query_req.llm_doc_eval,
)
answer_config = AnswerStyleConfig(
@@ -185,6 +197,7 @@ def stream_answer_objects(
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
answer = Answer(
question=query_msg.message,
answer_style_config=answer_config,
@@ -193,19 +206,23 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
tool_name=search_tool.name(),
tool_name=search_tool.name,
args={"query": rephrased_query},
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
# (likely fine that it comes after the initial creation of the search docs)
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, packet.response)
@@ -238,6 +255,7 @@ def stream_answer_objects(
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
@@ -249,8 +267,21 @@ def stream_answer_objects(
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(
relevance_summaries=packet.response
)
if reference_db_search_docs is not None:
update_search_docs_table_with_relevance(
db_session=db_session,
reference_db_search_docs=reference_db_search_docs,
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet
@@ -271,7 +302,6 @@ def stream_answer_objects(
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
@@ -305,8 +335,9 @@ def get_search_answer(
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> OneShotQAResponse:
"""Collects the streamed one shot answer responses into a single object"""

View File

@@ -27,12 +27,19 @@ class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
agentic: bool | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
# This is to forcibly skip (or run) the step; if None, it uses the system defaults
skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None
chain_of_thought: bool = False
return_contexts: bool = False
# This is to toggle agentic evaluation:
# 1. Evaluates whether each document is relevant or not
# 2. Provides a summary of the document's relevance in the results
llm_doc_eval: bool = False
# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@root_validator
def check_chain_of_thought_and_prompt_id(
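A hedged example of the new request shape (the field names come from the model above; the message payload, endpoint usage, and values are illustrative assumptions):

direct_qa_request = {
    "messages": [{"message": "How often do API keys rotate?"}],
    "prompt_id": None,
    "persona_id": 0,
    # run the new agentic per-document relevance evaluation
    "llm_doc_eval": True,
    # return search results only, without a generative answer
    "skip_gen_ai_answer_generation": True,
}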

View File

@@ -0,0 +1,43 @@
AGENTIC_SEARCH_SYSTEM_PROMPT = """
You are an expert at evaluating the relevance of a document to a search query.
Given a document and a search query, you determine whether the document is relevant to the user query.
You ALWAYS output the 3 sections described below and every section always begins with the same header line.
The "Chain of Thought" is to help you understand the document and query and their relevance to one another.
The "Useful Analysis" is shown to the user to help them understand why the document is or is not useful for them.
The "Final Relevance Determination" is always a single True or False.
You always output your response following these 3 sections:
1. Chain of Thought:
Provide a chain of thought analysis considering:
- The main purpose and content of the document
- What the user is searching for
- How the document relates to the query
- Potential uses of the document for the given query
Be thorough, but avoid unnecessary repetition. Think step by step.
2. Useful Analysis:
Summarize the contents of the document as it relates to the user query.
BE ABSOLUTELY AS CONCISE AS POSSIBLE.
If the document is not useful, briefly mention what the document is about.
Do NOT say whether this document is useful or not useful, ONLY provide the summary.
If referring to the document, prefer using "this" document over "the" document.
3. Final Relevance Determination:
True or False
"""
AGENTIC_SEARCH_USER_PROMPT = """
Document:
```
{content}
```
Query:
{query}
Be sure to run through the 3 steps of evaluation:
1. Chain of Thought
2. Useful Analysis
3. Final Relevance Determination
""".strip()

View File

@@ -144,6 +144,23 @@ Follow Up Input: {{question}}
Standalone question (Respond with only the short combined query):
""".strip()
INTERNET_SEARCH_QUERY_REPHRASE = f"""
Given the following conversation and a follow up input, rephrase the follow up into a SHORT, \
standalone query suitable for an internet search engine.
IMPORTANT: If a specific query might limit results, keep it broad. \
If a broad query might yield too many results, make it detailed.
If there is a clear change in topic, ensure the query reflects the new topic accurately.
Strip out any information that is not relevant for the internet search.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
Internet Search Query (Respond with a detailed and specific query):
""".strip()
# The below prompts are retired
NO_SEARCH = "No Search"

View File

@@ -4,7 +4,7 @@
USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
CHUNK_FILTER_PROMPT = f"""
SECTION_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
@@ -27,4 +27,4 @@ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(CHUNK_FILTER_PROMPT)
print(SECTION_FILTER_PROMPT)

View File

@@ -28,8 +28,3 @@ class SearchType(str, Enum):
class QueryFlow(str, Enum):
SEARCH = "search"
QUESTION_ANSWER = "question-answer"
class EmbedTextType(str, Enum):
QUERY = "query"
PASSAGE = "passage"

View File

@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel
from pydantic import validator
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
@@ -47,8 +49,8 @@ class ChunkMetric(BaseModel):
class ChunkContext(BaseModel):
# Additional surrounding context options; if full doc, then chunks are deduped
# If surrounding contexts overlap, they are combined into one
chunks_above: int = 0
chunks_below: int = 0
chunks_above: int = CONTEXT_CHUNKS_ABOVE
chunks_below: int = CONTEXT_CHUNKS_BELOW
full_doc: bool = False
@validator("chunks_above", "chunks_below", pre=True, each_item=False)
@@ -94,7 +96,7 @@ class SearchQuery(ChunkContext):
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
# Only used if not skip_llm_chunk_filter
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
max_llm_filter_sections: int = NUM_RERANKED_RESULTS
class Config:
frozen = True
@@ -128,11 +130,14 @@ class InferenceChunk(BaseChunk):
recency_bias: float
score: float | None
hidden: bool
is_relevant: bool | None = None
relevance_explanation: str | None = None
metadata: dict[str, str | list[str]]
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
# when the doc was last updated
updated_at: datetime | None
primary_owners: list[str] | None = None
@@ -162,20 +167,54 @@ class InferenceChunk(BaseChunk):
def __hash__(self) -> int:
return hash((self.document_id, self.chunk_id))
def __lt__(self, other: Any) -> bool:
if not isinstance(other, InferenceChunk):
return NotImplemented
if self.score is None:
if other.score is None:
return self.chunk_id > other.chunk_id
return True
if other.score is None:
return False
if self.score == other.score:
return self.chunk_id > other.chunk_id
return self.score < other.score
class InferenceSection(InferenceChunk):
"""Section is a combination of chunks. A section could be a single chunk, several consecutive
chunks or the entire document"""
def __gt__(self, other: Any) -> bool:
if not isinstance(other, InferenceChunk):
return NotImplemented
if self.score is None:
return False
if other.score is None:
return True
if self.score == other.score:
return self.chunk_id < other.chunk_id
return self.score > other.score
class InferenceChunkUncleaned(InferenceChunk):
title: str | None # Separate from Semantic Identifier though often same
metadata_suffix: str | None
def to_inference_chunk(self) -> InferenceChunk:
# Create a dict of all fields except 'title' and 'metadata_suffix'
# Assumes the cleaning has already been applied and just needs to translate to the right type
inference_chunk_data = {
k: v
for k, v in self.dict().items()
if k not in ["title", "metadata_suffix"]
}
return InferenceChunk(**inference_chunk_data)
class InferenceSection(BaseModel):
"""Section list of chunks with a combined content. A section could be a single chunk, several
chunks from the same document or the entire document."""
center_chunk: InferenceChunk
chunks: list[InferenceChunk]
combined_content: str
@classmethod
def from_chunk(
cls, inf_chunk: InferenceChunk, content: str | None = None
) -> "InferenceSection":
inf_chunk_data = inf_chunk.dict()
return cls(**inf_chunk_data, combined_content=content or inf_chunk.content)
class SearchDoc(BaseModel):
document_id: str
@@ -191,6 +230,8 @@ class SearchDoc(BaseModel):
hidden: bool
metadata: dict[str, str | list[str]]
score: float | None
is_relevant: bool | None = None
relevance_explanation: str | None = None
# Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>""]
@@ -199,6 +240,7 @@ class SearchDoc(BaseModel):
updated_at: datetime | None
primary_owners: list[str] | None
secondary_owners: list[str] | None
is_internet: bool = False
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
@@ -229,6 +271,13 @@ class SavedSearchDoc(SearchDoc):
return self.score < other.score
class SavedSearchDocWithContent(SavedSearchDoc):
"""Used for endpoints that need to return the actual contents of the retrieved
section in addition to the match_highlights."""
content: str
class RetrievalDocs(BaseModel):
top_documents: list[SavedSearchDoc]
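The `__lt__`/`__gt__` pair above orders chunks so that missing scores sort last and tied scores break toward the lower chunk_id; a standalone sketch of the same rules on a toy class (not the real model):

from dataclasses import dataclass

@dataclass
class Chunk:
    chunk_id: int
    score: float | None

    def __lt__(self, other: "Chunk") -> bool:
        # None scores sort below real scores; ties prefer the lower chunk_id
        if self.score is None:
            return other.score is not None or self.chunk_id > other.chunk_id
        if other.score is None:
            return False
        if self.score == other.score:
            return self.chunk_id > other.chunk_id
        return self.score < other.score

chunks = [Chunk(chunk_id=2, score=None), Chunk(chunk_id=1, score=0.9), Chunk(chunk_id=0, score=0.9)]
# Best-first ordering: tied 0.9 scores break toward chunk_id 0, the None score comes last
print(sorted(chunks, reverse=True))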

View File

@@ -1,15 +1,20 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.chat.models import RelevanceChunk
from danswer.configs.chat_configs import DISABLE_AGENTIC_SEARCH
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -23,31 +28,14 @@ from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ChunkRange(BaseModel):
chunk: InferenceChunk
start: int
end: int
combined_content: str | None = None
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
"""This acts on a single document to merge the overlapping ranges of sections
Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
"""
sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
ans: list[ChunkRange] = []
for chunk_range in sorted_ranges:
if not ans or ans[-1].end < chunk_range.start:
ans.append(chunk_range)
else:
ans[-1].end = max(ans[-1].end, chunk_range.end)
return ans
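(This ChunkRange/merge pair is being relocated to danswer.llm.answering.prune_and_merge, per the imports above.) A quick worked example of the merge semantics, on plain (start, end) tuples since constructing real chunks is beside the point:

def merge_intervals(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    # Same algorithm as merge_chunk_intervals, on bare tuples
    merged: list[tuple[int, int]] = []
    for start, end in sorted(ranges):
        if not merged or merged[-1][1] < start:
            merged.append((start, end))
        else:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
    return merged

# Hits at chunks 3 and 5 with one chunk of context each overlap at chunk 4,
# so two index fetches suffice instead of three
print(merge_intervals([(2, 4), (4, 6), (9, 11)]))  # [(2, 6), (9, 11)]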
logger = setup_logger()
class SearchPipeline:
@@ -59,9 +47,12 @@ class SearchPipeline:
fast_llm: LLM,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
pruning_config: DocumentPruningConfig | None = None,
):
self.search_request = search_request
self.user = user
@@ -77,61 +68,116 @@ class SearchPipeline:
primary_index_name=self.embedding_model.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config
self.pruning_config: DocumentPruningConfig | None = pruning_config
# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
self._predicted_search_type: SearchType | None = None
self._predicted_flow: QueryFlow | None = None
# Initial document index retrieval chunks
self._retrieved_chunks: list[InferenceChunk] | None = None
# Another call made to the document index to get surrounding sections
self._retrieved_sections: list[InferenceSection] | None = None
self._reranked_chunks: list[InferenceChunk] | None = None
# Reranking and LLM section selection can be run together
# If only LLM selection is on, the reranked chunks are yielded immediately
self._reranked_sections: list[InferenceSection] | None = None
self._relevant_chunk_indices: list[int] | None = None
self._relevant_section_indices: list[int] | None = None
# If chunks have been merged, the LLM filter flow no longer applies
# as the indices no longer match. Can be implemented later as needed
self.ran_merge_chunk = False
# Generates reranked chunks and LLM selections
self._postprocessing_generator: (
Iterator[list[InferenceSection] | list[int]] | None
) = None
# generator state
self._postprocessing_generator: Generator[
list[InferenceChunk] | list[str], None, None
] | None = None
"""Pre-processing"""
def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]:
if not post_rerank and self._retrieved_sections:
def _run_preprocessing(self) -> None:
(
final_search_query,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
llm=self.llm,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
self._search_query = final_search_query
self._predicted_search_type = predicted_search_type
self._predicted_flow = predicted_flow
@property
def search_query(self) -> SearchQuery:
if self._search_query is not None:
return self._search_query
self._run_preprocessing()
return cast(SearchQuery, self._search_query)
@property
def predicted_search_type(self) -> SearchType:
if self._predicted_search_type is not None:
return self._predicted_search_type
self._run_preprocessing()
return cast(SearchType, self._predicted_search_type)
@property
def predicted_flow(self) -> QueryFlow:
if self._predicted_flow is not None:
return self._predicted_flow
self._run_preprocessing()
return cast(QueryFlow, self._predicted_flow)
"""Retrieval and Postprocessing"""
def _get_chunks(self) -> list[InferenceChunk]:
"""TODO as a future extension:
If large chunks (above 512 tokens) are used which cannot be directly fed to the LLM,
this step should run two retrievals to get all of the base-size chunks
"""
if self._retrieved_chunks is not None:
return self._retrieved_chunks
self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
hybrid_alpha=self.search_request.hybrid_alpha,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)
return cast(list[InferenceChunk], self._retrieved_chunks)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) are specified, then it will give back all of the whole docs
that have a corresponding chunk.
This step should be fast for any document index implementation.
"""
if self._retrieved_sections is not None:
return self._retrieved_sections
if post_rerank and self._reranked_sections:
return self._reranked_sections
if not post_rerank:
chunks = self.retrieved_chunks
else:
chunks = self.reranked_chunks
retrieved_chunks = self._get_chunks()
if self._search_query is None:
# Should never happen
raise RuntimeError("Failed in Query Preprocessing")
above = self.search_query.chunks_above
below = self.search_query.chunks_below
functions_with_args: list[tuple[Callable, tuple]] = []
final_inference_sections = []
# Nothing to combine, just return the chunks
if (
not self._search_query.chunks_above
and not self._search_query.chunks_below
and not self._search_query.full_doc
):
return [InferenceSection.from_chunk(chunk) for chunk in chunks]
# If chunk merges have been run, LLM reranking loses meaning
# Needs reimplementation, out of scope for now
self.ran_merge_chunk = True
expanded_inference_sections = []
# Full doc setting takes priority
if self._search_query.full_doc:
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
for chunk in chunks:
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
seen_document_ids.add(chunk.document_id)
unique_chunks.append(chunk)
@@ -156,43 +202,54 @@ class SearchPipeline:
for ind, chunk in enumerate(unique_chunks):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
final_inference_sections.append(
InferenceSection.from_chunk(chunk, content=combined_content)
inference_section = inference_section_from_chunks(
center_chunk=chunk,
chunks=inf_chunks,
)
return final_inference_sections
if inference_section is not None:
expanded_inference_sections.append(inference_section)
else:
logger.warning("Skipped creation of section, no chunks found")
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
# General flow:
# - Combine chunks into lists by document_id
# - For each document, run merge-intervals to get combined ranges
# - This allows for fewer queries to the document index
# - Fetch all of the new chunks with contents for the combined ranges
# - Map it back to the combined ranges (which each know their "center" chunk)
# - Iterate over the chunks again and map to the results above based on the chunk.
# This maintains the original chunks ordering. Note, we cannot simply sort by score here
# as reranking flow may wipe the scores for a lot of the chunks.
doc_chunk_ranges_map = defaultdict(list)
for chunk in chunks:
for chunk in retrieved_chunks:
# The list of ranges for each document is ordered by score
doc_chunk_ranges_map[chunk.document_id].append(
ChunkRange(
chunk=chunk,
start=max(0, chunk.chunk_id - self._search_query.chunks_above),
chunks=[chunk],
start=max(0, chunk.chunk_id - above),
# No max known ahead of time, filter will handle this anyway
end=chunk.chunk_id + self._search_query.chunks_below,
end=chunk.chunk_id + below,
)
)
# List of ranges, outside list represents documents, inner list represents ranges
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
reverse_map = {r.chunk: r for doc_ranges in merged_ranges for r in doc_ranges}
flat_ranges = [r for ranges in merged_ranges for r in ranges]
for chunk_range in reverse_map.values():
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
chunk_range.chunk.document_id,
# Only need the document_id here; any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
@@ -206,152 +263,107 @@ class SearchPipeline:
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
for ind, chunk_range in enumerate(reverse_map.values()):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
chunk_range.combined_content = combined_content
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk
for chunk in flattened_inference_chunks
}
for chunk in chunks:
if chunk not in reverse_map:
continue
chunk_range = reverse_map[chunk]
final_inference_sections.append(
InferenceSection.from_chunk(
chunk_range.chunk, content=chunk_range.combined_content
)
# Build the surroundings for all of the initial retrieved chunks
for chunk in retrieved_chunks:
start_ind = max(0, chunk.chunk_id - above)
end_ind = chunk.chunk_id + below
# Since the index of the max_chunk is unknown, just allow it to be None and filter after
surrounding_chunks_or_none = [
doc_chunk_ind_to_chunk.get((chunk.document_id, chunk_ind))
for chunk_ind in range(start_ind, end_ind + 1) # end_ind is inclusive
]
# The Nones correspond to would-be "chunks" beyond the index of the last chunk
# of the document
surrounding_chunks = [
chunk for chunk in surrounding_chunks_or_none if chunk is not None
]
inference_section = inference_section_from_chunks(
center_chunk=chunk,
chunks=surrounding_chunks,
)
if inference_section is not None:
expanded_inference_sections.append(inference_section)
else:
logger.warning("Skipped creation of section, no chunks found")
return final_inference_sections
"""Pre-processing"""
def _run_preprocessing(self) -> None:
(
final_search_query,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
llm=self.llm,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
self._predicted_search_type = predicted_search_type
self._predicted_flow = predicted_flow
self._search_query = final_search_query
@property
def search_query(self) -> SearchQuery:
if self._search_query is not None:
return self._search_query
self._run_preprocessing()
return cast(SearchQuery, self._search_query)
@property
def predicted_search_type(self) -> SearchType:
if self._predicted_search_type is not None:
return self._predicted_search_type
self._run_preprocessing()
return cast(SearchType, self._predicted_search_type)
@property
def predicted_flow(self) -> QueryFlow:
if self._predicted_flow is not None:
return self._predicted_flow
self._run_preprocessing()
return cast(QueryFlow, self._predicted_flow)
"""Retrieval"""
@property
def retrieved_chunks(self) -> list[InferenceChunk]:
if self._retrieved_chunks is not None:
return self._retrieved_chunks
self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
hybrid_alpha=self.search_request.hybrid_alpha,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)
return cast(list[InferenceChunk], self._retrieved_chunks)
@property
def retrieved_sections(self) -> list[InferenceSection]:
# Calls retrieved_chunks inside
self._retrieved_sections = self._combine_chunks(post_rerank=False)
return self._retrieved_sections
"""Post-Processing"""
@property
def reranked_chunks(self) -> list[InferenceChunk]:
if self._reranked_chunks is not None:
return self._reranked_chunks
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_chunks,
llm=self.fast_llm, # use fast_llm for relevance, since it is a relatively easier task
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_chunks = cast(
list[InferenceChunk], next(self._postprocessing_generator)
)
return self._reranked_chunks
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
# Calls reranked_chunks inside
self._reranked_sections = self._combine_chunks(post_rerank=True)
"""Reranking is always done at the chunk level since section merging could create arbitrarily
long sections which could be:
1. Longer than the maximum context limit of even large rerankers
2. Slow to calculate due to the quadratic scaling laws of Transformers
See implementation in search_postprocessing for details
"""
if self._reranked_sections is not None:
return self._reranked_sections
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_sections=self._get_sections(),
llm=self.fast_llm,
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_sections = cast(
list[InferenceSection], next(self._postprocessing_generator)
)
return self._reranked_sections
@property
def relevant_chunk_indices(self) -> list[int]:
# If chunks have been merged, then we cannot simply rely on the leading chunk
# relevance, there is no way to get the full relevance of the Section now
# without running a more token heavy pass. This can be an option but not
# implementing now.
if self.ran_merge_chunk:
return []
def relevant_section_indices(self) -> list[int]:
if self._relevant_section_indices is not None:
return self._relevant_section_indices
if self._relevant_chunk_indices is not None:
return self._relevant_chunk_indices
# run first step of postprocessing generator if not already done
reranked_docs = self.reranked_chunks
relevant_chunk_ids = next(
cast(Generator[list[str], None, None], self._postprocessing_generator)
self._relevant_section_indices = next(
cast(Iterator[list[int]], self._postprocessing_generator)
)
self._relevant_chunk_indices = [
ind
for ind, chunk in enumerate(reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
return self._relevant_chunk_indices
return self._relevant_section_indices
@property
def chunk_relevance_list(self) -> list[bool]:
return [
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
def relevance_summaries(self) -> dict[str, RelevanceChunk]:
if DISABLE_AGENTIC_SEARCH:
raise ValueError(
"Agentic saerch operation called while DISABLE_AGENTIC_SEARCH is toggled"
)
if len(self.reranked_sections) == 0:
logger.warning(
"No sections found in agentic search evalution. Returning empty dict."
)
return {}
sections = self.reranked_sections
functions = [
FunctionCall(
evaluate_inference_section, (section, self.search_query.query, self.llm)
)
for section in sections
]
results = run_functions_in_parallel(function_calls=functions)
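# Each parallel result is a single-entry dict keyed by "{document_id}-{chunk_id}"; flatten them into one mapping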
return {
next(iter(value)): value[next(iter(value))] for value in results.values()
}
@property
def section_relevance_list(self) -> list[bool]:
if self.ran_merge_chunk:
return [False] * len(self.reranked_sections)
return [
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
True if ind in self.relevant_section_indices else False
for ind in range(len(self.reranked_sections))
]
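A condensed sketch of the "General flow" described in _get_sections above: ranges are merged per document, all needed chunks are fetched once, then a section is rebuilt around each original hit in retrieval order. Toy data only; the real flow fetches from the document index in parallel.

def build_sections(hits, fetched, above=1, below=1):
    # hits: (document_id, chunk_id) pairs in retrieval-score order
    # fetched: set of (document_id, chunk_id) available after the merged-range fetch
    sections = []
    for doc_id, chunk_id in hits:
        wanted = range(max(0, chunk_id - above), chunk_id + below + 1)
        chunks = [(doc_id, ind) for ind in wanted if (doc_id, ind) in fetched]
        sections.append((doc_id, chunk_id, chunks))
    return sections

fetched = {("doc-a", ind) for ind in range(2, 7)} | {("doc-b", 0), ("doc-b", 1)}
for section in build_sections([("doc-a", 3), ("doc-a", 5), ("doc-b", 0)], fetched):
    print(section)
# ('doc-a', 3, [('doc-a', 2), ('doc-a', 3), ('doc-a', 4)])
# ('doc-a', 5, [('doc-a', 4), ('doc-a', 5), ('doc-a', 6)])
# ('doc-b', 0, [('doc-b', 0), ('doc-b', 1)])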

View File

@@ -1,9 +1,11 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast
import numpy
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.document_index.document_index_utils import (
@@ -12,12 +14,14 @@ from danswer.document_index.document_index_utils import (
from danswer.llm.interfaces import LLM
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
@@ -27,9 +31,12 @@ from danswer.utils.timing import log_function_time
logger = setup_logger()
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
section.center_chunk.source_links[0]
if section.center_chunk.source_links is not None
else "No Link"
for section in sections
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
@@ -43,6 +50,33 @@ def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter
def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
if not chunk.title or not chunk.content:
return chunk.content
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()
if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
return chunk.content
def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
if not chunk.metadata_suffix:
return chunk.content
return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
RETURN_SEPARATOR
)
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]
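An illustration of the cleanup semantics above, with made-up chunk text (RETURN_SEPARATOR is stood in for by a newline here):

title = "Quarterly Report"
suffix = "[source: finance wiki]"
content = f"{title}\nRevenue grew 12% quarter over quarter.\n{suffix}"

# strip the title prefix, then the metadata suffix, as cleanup_chunks does
body = content[len(title):].lstrip() if content.startswith(title) else content
body = body.removesuffix(suffix).rstrip("\n")
print(body)  # Revenue grew 12% quarter over quarter.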
@log_function_time(print_only=True)
def semantic_reranking(
query: str,
@@ -113,84 +147,113 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices)
def rerank_chunks(
def rerank_sections(
query: SearchQuery,
chunks_to_rerank: list[InferenceChunk],
sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceSection]:
"""Chunks are reranked rather than the containing sections, this is because of speed
implications, if reranking models have lower latency for long inputs in the future
we may rerank on the combined context of the section instead
Making the assumption here that oftentimes we want larger Sections to provide context
for the LLM to determine if a section is useful but for reranking, we don't need to be
as stringent. If the Section is relevant, we assume that the chunk rerank score will
also be high.
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
return ranked_chunks
chunk_id_to_section = {
section.center_chunk.unique_id: section for section in sections_to_rerank
}
ordered_sections = [chunk_id_to_section[chunk.unique_id] for chunk in ranked_chunks]
return ordered_sections
@log_function_time(print_only=True)
def filter_chunks(
def filter_sections(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
sections_to_filter: list[InferenceSection],
llm: LLM,
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
# For cost saving, we may turn this on
use_chunk: bool = False,
) -> list[InferenceSection]:
"""Filters sections based on whether the LLM thought they were relevant to the query.
This is applied to the section, which has more context than the chunk. Hopefully this yields more accurate LLM evaluations.
Returns a list of the unique chunk IDs that were marked as relevant"""
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
llm_chunk_selection = llm_batch_eval_chunks(
Returns a list of the sections that were marked as relevant
"""
sections_to_filter = sections_to_filter[: query.max_llm_filter_sections]
contents = [
section.center_chunk.content if use_chunk else section.combined_content
for section in sections_to_filter
]
llm_chunk_selection = llm_batch_eval_sections(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
section_contents=contents,
llm=llm,
)
return [
chunk.unique_id
for ind, chunk in enumerate(chunks_to_filter)
section
for ind, section in enumerate(sections_to_filter)
if llm_chunk_selection[ind]
]
def search_postprocessing(
search_query: SearchQuery,
retrieved_chunks: list[InferenceChunk],
retrieved_sections: list[InferenceSection],
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[str], None, None]:
) -> Iterator[list[InferenceSection] | list[int]]:
post_processing_tasks: list[FunctionCall] = []
rerank_task_id = None
chunks_yielded = False
sections_yielded = False
if should_rerank(search_query):
post_processing_tasks.append(
FunctionCall(
rerank_chunks,
rerank_sections,
(
search_query,
retrieved_chunks,
retrieved_sections,
rerank_metrics_callback,
),
)
)
rerank_task_id = post_processing_tasks[-1].result_id
else:
final_chunks = retrieved_chunks
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order
_log_top_chunk_links(search_query.search_type.value, final_chunks)
yield final_chunks
chunks_yielded = True
# since we know this is the final order.
# This way the user experience isn't delayed by the LLM step
_log_top_section_links(search_query.search_type.value, retrieved_sections)
yield retrieved_sections
sections_yielded = True
llm_filter_task_id = None
if should_apply_llm_based_relevance_filter(search_query):
post_processing_tasks.append(
FunctionCall(
filter_chunks,
filter_sections,
(
search_query,
retrieved_chunks[: search_query.max_llm_filter_chunks],
retrieved_sections[: search_query.max_llm_filter_sections],
llm,
),
)
@@ -202,30 +265,30 @@ def search_postprocessing(
if post_processing_tasks
else {}
)
reranked_chunks = cast(
list[InferenceChunk] | None,
reranked_sections = cast(
list[InferenceSection] | None,
post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
)
if reranked_chunks:
if chunks_yielded:
if reranked_sections:
if sections_yielded:
logger.error(
"Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
"Trying to yield re-ranked sections, but sections were already yielded. This should never happen."
)
else:
_log_top_chunk_links(search_query.search_type.value, reranked_chunks)
yield reranked_chunks
_log_top_section_links(search_query.search_type.value, reranked_sections)
yield reranked_sections
llm_chunk_selection = cast(
list[str] | None,
post_processing_results.get(str(llm_filter_task_id))
if llm_filter_task_id
else None,
)
if llm_chunk_selection is not None:
yield [
chunk.unique_id
for chunk in reranked_chunks or retrieved_chunks
if chunk.unique_id in llm_chunk_selection
llm_selected_section_ids = (
[
section.center_chunk.unique_id
for section in post_processing_results.get(str(llm_filter_task_id), [])
]
else:
yield cast(list[str], [])
if llm_filter_task_id
else []
)
yield [
index
for index, section in enumerate(reranked_sections or retrieved_sections)
if section.center_chunk.unique_id in llm_selected_section_ids
]
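The generator contract here is worth spelling out: the first yield is the final section ordering (available before the LLM filter finishes), the second is the list of LLM-selected indices into that ordering. A hedged consumption sketch, with search_query, sections, and fast_llm assumed already in scope:

gen = search_postprocessing(
    search_query=search_query,
    retrieved_sections=sections,
    llm=fast_llm,
)
final_sections = next(gen)    # reranked (or original) sections, best first
relevant_indices = next(gen)  # indices of the sections the LLM marked useful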

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
"""Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
"""Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
It splits up even foreign characters and unicode emojis without using UNK"""
tokenized_text = tokenizer.tokenize(text)
num_unk_tokens = len(
@@ -73,6 +73,7 @@ def recommend_search_flow(
non_stopword_percent = len(non_stopwords) / len(words)
# UNK tokens -> suggest Keyword (still may be valid QA)
# TODO do a better job with the classifier model and retire the heuristics
if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
if not keyword:
heuristic_search_type = SearchType.KEYWORD

View File

@@ -2,7 +2,6 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
@@ -36,8 +35,6 @@ def retrieval_preprocessing(
db_session: Session,
bypass_acl: bool = False,
include_query_intent: bool = True,
enable_auto_detect_filters: bool = False,
disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
@@ -63,10 +60,7 @@ def retrieval_preprocessing(
auto_detect_time_filter = True
auto_detect_source_filter = True
if disable_llm_filter_extraction:
auto_detect_time_filter = False
auto_detect_source_filter = False
elif enable_auto_detect_filters is False:
if not search_request.enable_auto_detect_filters:
logger.debug("Retrieval details disables auto detect filters")
auto_detect_time_filter = False
auto_detect_source_filter = False

View File

@@ -7,26 +7,28 @@ from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.search.enums import EmbedTextType
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
logger = setup_logger()
@@ -129,6 +131,8 @@ def doc_index_retrieval(
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
@@ -159,7 +163,7 @@ def doc_index_retrieval(
else:
raise RuntimeError("Invalid Search Flow")
return top_chunks
return cleanup_chunks(top_chunks)
def _simplify_text(text: str) -> str:
@@ -240,30 +244,10 @@ def retrieve_chunks(
return top_chunks
def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
if not inf_chunks:
raise ValueError("Cannot combine empty list of chunks")
# Use the first link of the document
first_chunk = inf_chunks[0]
chunk_texts = [chunk.content for chunk in inf_chunks]
return LlmDoc(
document_id=first_chunk.document_id,
content="\n".join(chunk_texts),
blurb=first_chunk.blurb,
semantic_identifier=first_chunk.semantic_identifier,
source_type=first_chunk.source_type,
metadata=first_chunk.metadata,
updated_at=first_chunk.updated_at,
link=first_chunk.source_links[0] if first_chunk.source_links else None,
source_links=first_chunk.source_links,
)
def inference_documents_from_ids(
def inference_sections_from_ids(
doc_identifiers: list[tuple[str, int]],
document_index: DocumentIndex,
) -> list[LlmDoc]:
) -> list[InferenceSection]:
# Currently only fetches whole docs
doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)
@@ -282,4 +266,17 @@ def inference_documents_from_ids(
# Any failures to retrieve would give a None, drop the Nones and empty lists
inference_chunks_sets = [res for res in parallel_results if res]
return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
return [
inference_section
for inference_section in [
inference_section_from_chunks(
# The scores will always be 0 because the fetching by id gives back
# no search scores. This is not needed though if the user is explicitly
# selecting a document.
center_chunk=chunk_set[0],
chunks=chunk_set,
)
for chunk_set in inference_chunks_sets
]
if inference_section is not None
]

View File

@@ -9,10 +9,10 @@ from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
@@ -40,25 +40,22 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases
# where this is more important, be sure to refresh with the actual model name
def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to reduce the memory usage caused by
# processes importing this file despite not using any of it
from transformers import AutoTokenizer # type: ignore
global _TOKENIZER
if _TOKENIZER[0] is None or (
_TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
):
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()
if model_name is None:
# This could be inaccurate
model_name = DOCUMENT_ENCODER_MODEL
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
@@ -84,20 +81,24 @@ def build_model_server_url(
class EmbeddingModel:
def __init__(
self,
model_name: str,
query_prefix: str | None,
passage_prefix: str | None,
normalize: bool,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
# The following are globals and are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
self.model_name = model_name
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -111,10 +112,13 @@ class EmbeddingModel:
prefixed_texts = texts
embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
@@ -177,6 +181,7 @@ def warm_up_encoders(
"https://docs.danswer.dev/quickstart"
)
# May not be the exact same tokenizer used for the indexing flow
get_default_tokenizer(model_name=model_name)(warm_up_str)
embed_model = EmbeddingModel(
@@ -187,6 +192,8 @@ def warm_up_encoders(
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)
# First time downloading the models it may take even longer, but just in case,
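With the new api_key/provider_type parameters, a cloud-hosted embedding model is exercised through the same interface as a local one. A hedged construction sketch; the model name, provider identifier, and key are placeholders:

from danswer.search.search_nlp_models import EmbeddingModel
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType

model = EmbeddingModel(
    server_host=MODEL_SERVER_HOST,
    server_port=MODEL_SERVER_PORT,
    model_name="text-embedding-3-small",  # placeholder
    normalize=False,
    query_prefix=None,
    passage_prefix=None,
    api_key="sk-...",        # placeholder
    provider_type="openai",  # placeholder provider identifier
)
embeddings = model.encode(["Test String"], text_type=EmbedTextType.QUERY)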

View File

@@ -5,10 +5,18 @@ from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SavedSearchDocWithContent
from danswer.search.models import SearchDoc
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
T = TypeVar(
"T",
InferenceSection,
InferenceChunk,
SearchDoc,
SavedSearchDoc,
SavedSearchDocWithContent,
)
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
@@ -16,8 +24,13 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
deduped_items = []
dropped_indices = []
for index, item in enumerate(items):
if item.document_id not in seen_ids:
seen_ids.add(item.document_id)
if isinstance(item, InferenceSection):
document_id = item.center_chunk.document_id
else:
document_id = item.document_id
if document_id not in seen_ids:
seen_ids.add(document_id)
deduped_items.append(item)
else:
dropped_indices.append(index)
@@ -37,30 +50,51 @@ def drop_llm_indices(
return [i for i, val in enumerate(llm_bools) if val]
def chunks_or_sections_to_search_docs(
chunks: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
search_docs = (
[
SearchDoc(
document_id=chunk.document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
)
for chunk in chunks
]
if chunks
else []
def inference_section_from_chunks(
center_chunk: InferenceChunk,
chunks: list[InferenceChunk],
) -> InferenceSection | None:
if not chunks:
return None
combined_content = "\n".join([chunk.content for chunk in chunks])
return InferenceSection(
center_chunk=center_chunk,
chunks=chunks,
combined_content=combined_content,
)
def chunks_or_sections_to_search_docs(
items: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
if not items:
return []
search_docs = [
SearchDoc(
document_id=(
chunk := item.center_chunk
if isinstance(item, InferenceSection)
else item
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]
return search_docs
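The assignment expression in the comprehension above binds the unwrapped chunk once per item so the remaining fields can reuse it; a minimal demonstration of the same idiom on toy data:

class Wrapper:
    def __init__(self, inner: dict):
        self.inner = inner

def describe(items: list) -> list[dict]:
    # Bind the unwrapped object once with := and reuse it for every field
    return [
        {
            "id": (obj := item.inner if isinstance(item, Wrapper) else item)["id"],
            "name": obj["name"],
        }
        for item in items
    ]

print(describe([Wrapper({"id": 1, "name": "a"}), {"id": 2, "name": "b"}]))
# [{'id': 1, 'name': 'a'}, {'id': 2, 'name': 'b'}]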

View File

@@ -0,0 +1,70 @@
import re
from danswer.chat.models import RelevanceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_SYSTEM_PROMPT
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_USER_PROMPT
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_agent_eval_messages(
title: str, content: str, query: str
) -> list[dict[str, str]]:
messages = [
{
"role": "system",
"content": AGENTIC_SEARCH_SYSTEM_PROMPT,
},
{
"role": "user",
"content": AGENTIC_SEARCH_USER_PROMPT.format(
title=title, content=content, query=query
),
},
]
return messages
def evaluate_inference_section(
document: InferenceSection, query: str, llm: LLM
) -> dict[str, RelevanceChunk]:
results = {}
document_id = document.center_chunk.document_id
semantic_id = document.center_chunk.semantic_identifier
contents = document.combined_content
chunk_id = document.center_chunk.chunk_id
messages = _get_agent_eval_messages(
title=semantic_id, content=contents, query=query
)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
# Search for the "Useful Analysis" section in the model output
# This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon,
# then any text up to "3. Final Relevance"
# The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines
# If no match is found, the entire model output is used as the analysis
analysis_match = re.search(
r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance",
model_output,
re.DOTALL,
)
analysis = analysis_match.group(1).strip() if analysis_match else model_output
# Get the last non-empty line
last_line = next(
(line for line in reversed(model_output.split("\n")) if line.strip()), ""
)
relevant = last_line.strip().lower().startswith("true")
results[f"{document_id}-{chunk_id}"] = RelevanceChunk(
relevant=relevant, content=analysis
)
return results
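To make the extraction concrete, a small self-test of the regex and the final-line check against a well-formed (invented) model response:

import re

sample = """1. Chain of Thought:
The document describes the key rotation policy in detail.

2. Useful Analysis:
This document states that API keys rotate every 90 days.

3. Final Relevance Determination:
True"""

match = re.search(
    r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance",
    sample,
    re.DOTALL,
)
print(match.group(1).strip())
# This document states that API keys rotate every 90 days.

last_line = next(line for line in reversed(sample.split("\n")) if line.strip())
print(last_line.strip().lower().startswith("true"))  # True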

View File

@@ -3,21 +3,21 @@ from collections.abc import Callable
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
from danswer.prompts.llm_chunk_filter import NONUSEFUL_PAT
from danswer.prompts.llm_chunk_filter import SECTION_FILTER_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool:
def _get_usefulness_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": CHUNK_FILTER_PROMPT.format(
chunk_text=chunk_content, user_query=query
"content": SECTION_FILTER_PROMPT.format(
chunk_text=section_content, user_query=query
),
},
]
@@ -42,13 +42,13 @@ def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
return _extract_usefulness(model_output)
def llm_batch_eval_chunks(
query: str, chunk_contents: list[str], llm: LLM, use_threads: bool = True
def llm_batch_eval_sections(
query: str, section_contents: list[str], llm: LLM, use_threads: bool = True
) -> list[bool]:
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_chunk, (query, chunk_content, llm))
for chunk_content in chunk_contents
(llm_eval_section, (query, section_content, llm))
for section_content in section_contents
]
logger.debug(
@@ -58,11 +58,11 @@ def llm_batch_eval_chunks(
functions_with_args, allow_failures=True
)
# In case of failure/timeout, don't throw out the chunk
# In case of failure/timeout, don't throw out the section
return [True if item is None else item for item in parallel_results]
else:
return [
llm_eval_chunk(query, chunk_content, llm)
for chunk_content in chunk_contents
llm_eval_section(query, section_content, llm)
for section_content in section_contents
]

View File

@@ -74,11 +74,12 @@ def multilingual_query_expansion(
def get_contextual_rephrase_messages(
question: str,
history_str: str,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": HISTORY_QUERY_REPHRASE.format(
"content": prompt_template.format(
question=question, chat_history=history_str
),
},
@@ -94,6 +95,7 @@ def history_based_query_rephrase(
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query
if DISABLE_LLM_QUERY_REPHRASE:
@@ -119,7 +121,7 @@ def history_based_query_rephrase(
)
prompt_msgs = get_contextual_rephrase_messages(
question=query, history_str=history_str
question=query, history_str=history_str, prompt_template=prompt_template
)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
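With the new prompt_template parameter, the internet-search rephrase prompt plugs into the same message-building path; a short sketch using get_contextual_rephrase_messages (the import path for the prompt constant is assumed, and the history string is invented):

from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE

msgs = get_contextual_rephrase_messages(
    question="any updates on that?",
    history_str="user: what is the Q3 roadmap?\nassistant: The Q3 roadmap focuses on...",
    prompt_template=INTERNET_SEARCH_QUERY_REPHRASE,
)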

View File

@@ -77,7 +77,7 @@ def associate_credential_to_connector(
    connector_id: int,
    credential_id: int,
    metadata: ConnectorCredentialPairMetadata,
-    user: User = Depends(current_user),
+    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
    try:
@@ -97,7 +97,7 @@ def associate_credential_to_connector(
def dissociate_credential_from_connector(
    connector_id: int,
    credential_id: int,
-    user: User = Depends(current_user),
+    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
    return remove_credential_from_connector(

View File

@@ -10,6 +10,7 @@ class ToolSnapshot(BaseModel):
    name: str
    description: str
    definition: dict[str, Any] | None
+    display_name: str
    in_code_tool_id: str | None

    @classmethod
@@ -19,5 +20,6 @@ class ToolSnapshot(BaseModel):
            name=tool.name,
            description=tool.description,
            definition=tool.openapi_schema,
+            display_name=tool.display_name or tool.name,
            in_code_tool_id=tool.in_code_tool_id,
        )

View File

@@ -68,7 +68,7 @@ def gpt_search(
    db_session: Session = Depends(get_session),
) -> GptSearchResponse:
    llm, fast_llm = get_default_llms()
-    top_chunks = SearchPipeline(
+    top_sections = SearchPipeline(
        search_request=SearchRequest(
            query=search_request.query,
        ),
@@ -76,20 +76,22 @@ def gpt_search(
        llm=llm,
        fast_llm=fast_llm,
        db_session=db_session,
-    ).reranked_chunks
+    ).reranked_sections

    return GptSearchResponse(
        matching_document_chunks=[
            GptDocChunk(
-                title=chunk.semantic_identifier,
-                content=chunk.content,
-                source_type=chunk.source_type,
-                link=chunk.source_links.get(0, "") if chunk.source_links else "",
-                metadata=chunk.metadata,
-                document_age=time_ago(chunk.updated_at)
-                if chunk.updated_at
+                title=section.center_chunk.semantic_identifier,
+                content=section.center_chunk.content,
+                source_type=section.center_chunk.source_type,
+                link=section.center_chunk.source_links.get(0, "")
+                if section.center_chunk.source_links
+                else "",
+                metadata=section.center_chunk.metadata,
+                document_age=time_ago(section.center_chunk.updated_at)
+                if section.center_chunk.updated_at
                else "Unknown",
            )
-            for chunk in top_chunks
+            for section in top_sections
        ],
    )
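Every field is now read through section.center_chunk, i.e. a reranked section is a chunk plus surrounding context, with the original chunk accessible at its center. A toy model of the shapes involved (field names are taken from the attribute accesses above; the classes are simplified stand-ins for Danswer's models):

from dataclasses import dataclass, field
from datetime import datetime

@dataclass
class Chunk:  # stand-in for the inference chunk
    semantic_identifier: str
    content: str
    source_type: str
    metadata: dict = field(default_factory=dict)
    source_links: dict[int, str] | None = None
    updated_at: datetime | None = None

@dataclass
class Section:  # stand-in for a reranked section
    center_chunk: Chunk  # the chunk the section was built around

section = Section(center_chunk=Chunk("Doc title", "body text", "web"))
link = (
    section.center_chunk.source_links.get(0, "")
    if section.center_chunk.source_links
    else ""
)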

View File

@@ -0,0 +1,93 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.embedding_model import get_current_db_embedding_provider
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType

logger = setup_logger()

admin_router = APIRouter(prefix="/admin/embedding")
basic_router = APIRouter(prefix="/embedding")


@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_llm_request: TestEmbeddingRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    try:
        test_model = EmbeddingModel(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
            provider_type=test_llm_request.provider,
            normalize=False,
            query_prefix=None,
            passage_prefix=None,
            model_name=None,
        )
        test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=error_msg)


@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
    return [
        CloudEmbeddingProvider.from_request(embedding_provider_model)
        for embedding_provider_model in fetch_existing_embedding_providers(db_session)
    ]


@admin_router.delete("/embedding-provider/{embedding_provider_name}")
def delete_embedding_provider(
    embedding_provider_name: str,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    embedding_provider = get_current_db_embedding_provider(db_session=db_session)
    if (
        embedding_provider is not None
        and embedding_provider_name == embedding_provider.name
    ):
        raise HTTPException(
            status_code=400, detail="You can't delete a currently active model"
        )
    remove_embedding_provider(db_session, embedding_provider_name)


@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
    provider: CloudEmbeddingProviderCreationRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
    return upsert_cloud_embedding_provider(db_session, provider)
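Assuming a running API server with admin auth already in place, exercising the new router might look like the following (the base URL, session auth, and payload values are illustrative; the paths come from the router prefixes above):

import requests

BASE = "http://localhost:8080"  # assumed server address
session = requests.Session()    # assumed to already carry admin credentials

# Validate a cloud embedding key before persisting it.
session.post(
    f"{BASE}/admin/embedding/test-embedding",
    json={"provider": "openai", "api_key": "sk-..."},
).raise_for_status()

# Upsert the provider, then list what is registered.
session.put(
    f"{BASE}/admin/embedding/embedding-provider",
    json={"name": "openai", "api_key": "sk-...", "default_model_id": None},
).raise_for_status()
providers = session.get(f"{BASE}/admin/embedding/embedding-provider").json()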

View File

@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
    from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel


class TestEmbeddingRequest(BaseModel):
    provider: str
    api_key: str | None = None


class CloudEmbeddingProvider(BaseModel):
    name: str
    api_key: str | None = None
    default_model_id: int | None = None
    id: int

    @classmethod
    def from_request(
        cls, cloud_provider_model: "CloudEmbeddingProviderModel"
    ) -> "CloudEmbeddingProvider":
        return cls(
            id=cloud_provider_model.id,
            name=cloud_provider_model.name,
            api_key=cloud_provider_model.api_key,
            default_model_id=cloud_provider_model.default_model_id,
        )


class CloudEmbeddingProviderCreationRequest(BaseModel):
    name: str
    api_key: str | None = None
    default_model_id: int | None = None

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
+from danswer.llm.llm_provider_options import fetch_models_for_provider

if TYPE_CHECKING:
    from danswer.db.models import LLMProvider as LLMProviderModel

View File

@@ -12,6 +12,8 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
+from danswer.db.models import StandardAnswer as StandardAnswerModel
+from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
@@ -84,6 +86,57 @@ class HiddenUpdateRequest(BaseModel):
    hidden: bool

+class StandardAnswerCategoryCreationRequest(BaseModel):
+    name: str

+class StandardAnswerCategory(BaseModel):
+    id: int
+    name: str

+    @classmethod
+    def from_model(
+        cls, standard_answer_category: StandardAnswerCategoryModel
+    ) -> "StandardAnswerCategory":
+        return cls(
+            id=standard_answer_category.id,
+            name=standard_answer_category.name,
+        )

+class StandardAnswer(BaseModel):
+    id: int
+    keyword: str
+    answer: str
+    categories: list[StandardAnswerCategory]

+    @classmethod
+    def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
+        return cls(
+            id=standard_answer_model.id,
+            keyword=standard_answer_model.keyword,
+            answer=standard_answer_model.answer,
+            categories=[
+                StandardAnswerCategory.from_model(standard_answer_category_model)
+                for standard_answer_category_model in standard_answer_model.categories
+            ],
+        )

+class StandardAnswerCreationRequest(BaseModel):
+    keyword: str
+    answer: str
+    categories: list[int]

+    @validator("categories", pre=True)
+    def validate_categories(cls, value: list[int]) -> list[int]:
+        if len(value) < 1:
+            raise ValueError(
+                "At least one category must be attached to a standard answer"
+            )
+        return value

class SlackBotTokens(BaseModel):
    bot_token: str
    app_token: str
@@ -102,13 +155,14 @@ class SlackBotConfigCreationRequest(BaseModel):
    channel_names: list[str]
    respond_tag_only: bool = False
    respond_to_bots: bool = False
+    enable_auto_filters: bool = False
    # If no team members, assume respond in the channel to everyone
-    respond_team_member_list: list[str] = []
-    respond_slack_group_list: list[str] = []
+    respond_member_group_list: list[str] = []
    answer_filters: list[AllowedAnswerFilters] = []
    # list of user emails
    follow_up_tags: list[str] | None = None
    response_type: SlackBotResponseType
+    standard_answer_categories: list[int] = []

    @validator("answer_filters", pre=True)
    def validate_filters(cls, value: list[str]) -> list[str]:
@@ -133,6 +187,8 @@ class SlackBotConfig(BaseModel):
    persona: PersonaSnapshot | None
    channel_config: ChannelConfig
    response_type: SlackBotResponseType
+    standard_answer_categories: list[StandardAnswerCategory]
+    enable_auto_filters: bool

    @classmethod
    def from_model(
@@ -149,6 +205,11 @@ class SlackBotConfig(BaseModel):
            ),
            channel_config=slack_bot_config_model.channel_config,
            response_type=slack_bot_config_model.response_type,
+            standard_answer_categories=[
+                StandardAnswerCategory.from_model(standard_answer_category_model)
+                for standard_answer_category_model in slack_bot_config_model.standard_answer_categories
+            ],
+            enable_auto_filters=slack_bot_config_model.enable_auto_filters,
        )
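The categories validator rejects an empty list before the request ever reaches the DB layer. A runnable distillation of just that rule, as a standalone Pydantic v1-style model mirroring the fields shown above:

from pydantic import BaseModel, ValidationError, validator

class CreationRequest(BaseModel):
    keyword: str
    answer: str
    categories: list[int]

    @validator("categories", pre=True)
    def validate_categories(cls, value: list[int]) -> list[int]:
        if len(value) < 1:
            raise ValueError("At least one category must be attached to a standard answer")
        return value

try:
    CreationRequest(keyword="vpn", answer="See the IT guide.", categories=[])
except ValidationError as e:
    print(e)  # -> At least one category must be attached to a standard answer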

View File

@@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.embedding_model import get_model_id_from_name
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
@@ -38,6 +39,19 @@ def set_new_embedding_model(
    """
    current_model = get_current_db_embedding_model(db_session)

+    if embed_model_details.cloud_provider_name is not None:
+        cloud_id = get_model_id_from_name(
+            db_session, embed_model_details.cloud_provider_name
+        )
+        if cloud_id is None:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="No ID exists for given provider name",
+            )
+        embed_model_details.cloud_provider_id = cloud_id

    if embed_model_details.model_name == current_model.model_name:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
View File

@@ -34,11 +34,8 @@ def _form_channel_config(
) -> ChannelConfig:
    raw_channel_names = slack_bot_config_creation_request.channel_names
    respond_tag_only = slack_bot_config_creation_request.respond_tag_only
-    respond_team_member_list = (
-        slack_bot_config_creation_request.respond_team_member_list
-    )
-    respond_slack_group_list = (
-        slack_bot_config_creation_request.respond_slack_group_list
+    respond_member_group_list = (
+        slack_bot_config_creation_request.respond_member_group_list
    )
    answer_filters = slack_bot_config_creation_request.answer_filters
    follow_up_tags = slack_bot_config_creation_request.follow_up_tags
@@ -61,7 +58,7 @@ def _form_channel_config(
            detail=str(e),
        )

-    if respond_tag_only and (respond_team_member_list or respond_slack_group_list):
+    if respond_tag_only and respond_member_group_list:
        raise ValueError(
            "Cannot set DanswerBot to only respond to tags only and "
            "also respond to a predetermined set of users."
@@ -72,10 +69,8 @@ def _form_channel_config(
    }
    if respond_tag_only is not None:
        channel_config["respond_tag_only"] = respond_tag_only
-    if respond_team_member_list:
-        channel_config["respond_team_member_list"] = respond_team_member_list
-    if respond_slack_group_list:
-        channel_config["respond_slack_group_list"] = respond_slack_group_list
+    if respond_member_group_list:
+        channel_config["respond_member_group_list"] = respond_member_group_list
    if answer_filters:
        channel_config["answer_filters"] = answer_filters
    if follow_up_tags is not None:
@@ -113,7 +108,9 @@ def create_slack_bot_config(
        persona_id=persona_id,
        channel_config=channel_config,
        response_type=slack_bot_config_creation_request.response_type,
+        standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
        db_session=db_session,
+        enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
    )

    return SlackBotConfig.from_model(slack_bot_config_model)
@@ -171,7 +168,9 @@ def patch_slack_bot_config(
        persona_id=persona_id,
        channel_config=channel_config,
        response_type=slack_bot_config_creation_request.response_type,
+        standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
        db_session=db_session,
+        enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
    )

    return SlackBotConfig.from_model(slack_bot_config_model)
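Because the two per-channel lists collapse into respond_member_group_list, a client holding old-style payloads would fold them together before calling the create/patch endpoints. A hypothetical migration snippet (the legacy payload shape is assumed from the removed fields):

old_payload = {
    "respond_team_member_list": ["alice@example.com"],
    "respond_slack_group_list": ["oncall"],
}
new_payload = {
    # Fold both legacy lists into the single field the API now expects.
    "respond_member_group_list": [
        *old_payload.pop("respond_team_member_list", []),
        *old_payload.pop("respond_slack_group_list", []),
    ],
    **old_payload,
}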

View File

@@ -0,0 +1,139 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.standard_answer import fetch_standard_answer
from danswer.db.standard_answer import fetch_standard_answer_categories
from danswer.db.standard_answer import fetch_standard_answer_category
from danswer.db.standard_answer import fetch_standard_answers
from danswer.db.standard_answer import insert_standard_answer
from danswer.db.standard_answer import insert_standard_answer_category
from danswer.db.standard_answer import remove_standard_answer
from danswer.db.standard_answer import update_standard_answer
from danswer.db.standard_answer import update_standard_answer_category
from danswer.server.manage.models import StandardAnswer
from danswer.server.manage.models import StandardAnswerCategory
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from danswer.server.manage.models import StandardAnswerCreationRequest

router = APIRouter(prefix="/manage")


@router.post("/admin/standard-answer")
def create_standard_answer(
    standard_answer_creation_request: StandardAnswerCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswer:
    standard_answer_model = insert_standard_answer(
        keyword=standard_answer_creation_request.keyword,
        answer=standard_answer_creation_request.answer,
        category_ids=standard_answer_creation_request.categories,
        db_session=db_session,
    )
    return StandardAnswer.from_model(standard_answer_model)


@router.get("/admin/standard-answer")
def list_standard_answers(
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> list[StandardAnswer]:
    standard_answer_models = fetch_standard_answers(db_session=db_session)
    return [
        StandardAnswer.from_model(standard_answer_model)
        for standard_answer_model in standard_answer_models
    ]


@router.patch("/admin/standard-answer/{standard_answer_id}")
def patch_standard_answer(
    standard_answer_id: int,
    standard_answer_creation_request: StandardAnswerCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswer:
    existing_standard_answer = fetch_standard_answer(
        standard_answer_id=standard_answer_id,
        db_session=db_session,
    )
    if existing_standard_answer is None:
        raise HTTPException(status_code=404, detail="Standard answer not found")

    standard_answer_model = update_standard_answer(
        standard_answer_id=standard_answer_id,
        keyword=standard_answer_creation_request.keyword,
        answer=standard_answer_creation_request.answer,
        category_ids=standard_answer_creation_request.categories,
        db_session=db_session,
    )
    return StandardAnswer.from_model(standard_answer_model)


@router.delete("/admin/standard-answer/{standard_answer_id}")
def delete_standard_answer(
    standard_answer_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    return remove_standard_answer(
        standard_answer_id=standard_answer_id,
        db_session=db_session,
    )


@router.post("/admin/standard-answer/category")
def create_standard_answer_category(
    standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswerCategory:
    standard_answer_category_model = insert_standard_answer_category(
        category_name=standard_answer_category_creation_request.name,
        db_session=db_session,
    )
    return StandardAnswerCategory.from_model(standard_answer_category_model)


@router.get("/admin/standard-answer/category")
def list_standard_answer_categories(
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> list[StandardAnswerCategory]:
    standard_answer_category_models = fetch_standard_answer_categories(
        db_session=db_session
    )
    return [
        StandardAnswerCategory.from_model(standard_answer_category_model)
        for standard_answer_category_model in standard_answer_category_models
    ]


@router.patch("/admin/standard-answer/category/{standard_answer_category_id}")
def patch_standard_answer_category(
    standard_answer_category_id: int,
    standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswerCategory:
    existing_standard_answer_category = fetch_standard_answer_category(
        standard_answer_category_id=standard_answer_category_id,
        db_session=db_session,
    )
    if existing_standard_answer_category is None:
        raise HTTPException(
            status_code=404, detail="Standard answer category not found"
        )

    standard_answer_category_model = update_standard_answer_category(
        standard_answer_category_id=standard_answer_category_id,
        category_name=standard_answer_category_creation_request.name,
        db_session=db_session,
    )
    return StandardAnswerCategory.from_model(standard_answer_category_model)
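Taken together this is a small CRUD surface: create a category, attach it to an answer, then patch or delete by ID. A sketch of a client session against it (server address and admin auth are assumed; the /manage prefix and routes come from the router above):

import requests

BASE = "http://localhost:8080/manage"  # assumed server address plus router prefix
session = requests.Session()           # assumed to carry admin credentials

category = session.post(
    f"{BASE}/admin/standard-answer/category", json={"name": "IT"}
).json()
answer = session.post(
    f"{BASE}/admin/standard-answer",
    json={
        "keyword": "vpn",
        "answer": "See the VPN setup guide.",
        "categories": [category["id"]],
    },
).json()
session.patch(
    f"{BASE}/admin/standard-answer/{answer['id']}",
    json={
        "keyword": "vpn",
        "answer": "Use the new VPN portal.",
        "categories": [category["id"]],
    },
)
session.delete(f"{BASE}/admin/standard-answer/{answer['id']}")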

Some files were not shown because too many files have changed in this diff.