Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-21 09:45:46 +00:00)

Compare commits: v0.3.92...eval/split (85 commits)
Commits in this comparison:

- 4293543a6a
- e95bfa0e0b
- 4848b5f1de
- 7ba5c434fa
- 59bf5ba848
- f66c33380c
- 115650ce9f
- 7aa3602fca
- 864c552a17
- 07b2ed3d8f
- 38290057f2
- 2344edf158
- 86d1804eb0
- 1ebae50d0c
- a9fbaa396c
- 27d5f69427
- 5d98421ae8
- 6b561b8ca9
- 2dc7e64dd7
- 5230f7e22f
- a595d43ae3
- ee561f42ff
- f00b3d76b3
- e4984153c0
- 87fadb07ea
- 2b07c102f9
- e93de602c3
- 1c77395503
- cdf6089b3e
- d01f46af2b
- b83f435bb0
- 25b3dacaba
- a1e638a73d
- bd1e0c5969
- 4d295ab97d
- 6fe3eeaa48
- 078d5defbb
- 0d52e99bd4
- 1b864a00e4
- dae4f6a0bd
- f63d0ca3ad
- da31da33e7
- 56b175f597
- 1b311d092e
- 6ee1292757
- 017af052be
- e7f81d1688
- b6bd818e60
- 36da2e4b27
- c7af6a4601
- e90c66c1b6
- 8c312482c1
- e50820e65e
- 991ee79e47
- 3e645a510e
- 08c6e821e7
- 47a550221f
- 511f619212
- 6c51f001dc
- 09a11b5e1a
- aa0f7abdac
- 7c8f8dba17
- 39982e5fdc
- 5e0de111f9
- 727d80f168
- 146f85936b
- e06f8a0a4b
- f0888f2f61
- d35d7ee833
- c5bb3fde94
- 79190030a5
- 8e8f262ed3
- ac14369716
- de4d8e9a65
- 0b384c5b34
- fa049f4f98
- 72d6a0ef71
- ae4e643266
- a7da07afc0
- 7f1bb67e52
- 982b1b0c49
- 2db128fb36
- 3ebac6256f
- 1a3ec59610
- 581cb827bb
@@ -1,8 +1,6 @@
name: Build Backend Image on Merge Group

on:
  pull_request:
    branches: [ "main" ]
  merge_group:
    types: [checks_requested]
@@ -1,8 +1,6 @@
name: Build Web Image on Merge Group

on:
  pull_request:
    branches: [ "main" ]
  merge_group:
    types: [checks_requested]
.gitignore (vendored): 1 change
@@ -6,3 +6,4 @@
/deployment/data/nginx/app.conf
.vscode/launch.json
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
.vscode/env_template.txt (vendored): 14 changes
@@ -8,7 +8,7 @@ AUTH_TYPE=disabled

# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_ALL_MODEL_INTERACTIONS=True
LOG_DANSWER_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug

@@ -25,11 +25,6 @@ OAUTH_CLIENT_SECRET=<REPLACE THIS>
REQUIRE_EMAIL_VERIFICATION=False


# Toggles on/off the EE Features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False


# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper

@@ -47,6 +42,11 @@ PYTHONPATH=./backend
PYTHONUNBUFFERED=1


# Internet Search
BING_API_KEY=<REPLACE THIS>


# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
.vscode/launch.template.jsonc (vendored): 23 changes
@@ -49,7 +49,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
    "LOG_ALL_MODEL_INTERACTIONS": "True",
    "LOG_DANSWER_MODEL_INTERACTIONS": "True",
    "LOG_LEVEL": "DEBUG",
    "PYTHONUNBUFFERED": "1"
},

@@ -83,6 +83,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
    "LOG_DANSWER_MODEL_INTERACTIONS": "True",
    "LOG_LEVEL": "DEBUG",
    "PYTHONUNBUFFERED": "1",
    "PYTHONPATH": "."

@@ -105,6 +106,24 @@
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      }
    },
    {
      "name": "Pytest",
      "type": "python",
      "request": "launch",
      "module": "pytest",
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.env",
      "env": {
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      },
      "args": [
        "-v"
        // Specify a sepcific module/test to run or provide nothing to run all tests
        //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
      ]
    }
  ]
}
@@ -11,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
  <img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
  <img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
backend/.gitignore (vendored): 2 changes
@@ -5,7 +5,7 @@ site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
@@ -0,0 +1,32 @@
"""add search doc relevance details

Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "search_doc",
        sa.Column("is_relevant", sa.Boolean(), nullable=True),
    )
    op.add_column(
        "search_doc",
        sa.Column("relevance_explanation", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("search_doc", "relevance_explanation")
    op.drop_column("search_doc", "is_relevant")
@@ -13,8 +13,8 @@ import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
@@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model

Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Create embedding_provider table
    op.create_table(
        "embedding_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("default_model_id", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )

    # Add cloud_provider_id to embedding_model table
    op.add_column(
        "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
    )

    # Add foreign key constraints
    op.create_foreign_key(
        "fk_embedding_model_cloud_provider",
        "embedding_model",
        "embedding_provider",
        ["cloud_provider_id"],
        ["id"],
    )
    op.create_foreign_key(
        "fk_embedding_provider_default_model",
        "embedding_provider",
        "embedding_model",
        ["default_model_id"],
        ["id"],
    )


def downgrade() -> None:
    # Remove foreign key constraints
    op.drop_constraint(
        "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
    )
    op.drop_constraint(
        "fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
    )

    # Remove cloud_provider_id column
    op.drop_column("embedding_model", "cloud_provider_id")

    # Drop embedding_provider table
    op.drop_table("embedding_provider")
@@ -0,0 +1,23 @@
"""added is_internet to DBDoc

Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"


def upgrade() -> None:
    op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
    op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("tool", "display_name")
    op.drop_column("search_doc", "is_internet")
@@ -0,0 +1,35 @@
"""added slack_auto_filter

Revision ID: 7aea705850d5
Revises: 4505fd7302e1
Create Date: 2024-07-10 11:01:23.581015

"""
from alembic import op
import sqlalchemy as sa


revision = "7aea705850d5"
down_revision = "4505fd7302e1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "slack_bot_config",
        sa.Column("enable_auto_filters", sa.Boolean(), nullable=True),
    )
    op.execute(
        "UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL"
    )
    op.alter_column(
        "slack_bot_config",
        "enable_auto_filters",
        existing_type=sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )


def downgrade() -> None:
    op.drop_column("slack_bot_config", "enable_auto_filters")
@@ -0,0 +1,23 @@
"""backfill is_internet data to False

Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")


def downgrade() -> None:
    pass
@@ -0,0 +1,75 @@
"""Add standard_answer tables

Revision ID: c18cdf4b497e
Revises: 3a7802814195
Create Date: 2024-06-06 15:15:02.000648

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c18cdf4b497e"
down_revision = "3a7802814195"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "standard_answer",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("keyword", sa.String(), nullable=False),
        sa.Column("answer", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("keyword"),
    )
    op.create_table(
        "standard_answer_category",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    op.create_table(
        "standard_answer__standard_answer_category",
        sa.Column("standard_answer_id", sa.Integer(), nullable=False),
        sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["standard_answer_category_id"],
            ["standard_answer_category.id"],
        ),
        sa.ForeignKeyConstraint(
            ["standard_answer_id"],
            ["standard_answer.id"],
        ),
        sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"),
    )
    op.create_table(
        "slack_bot_config__standard_answer_category",
        sa.Column("slack_bot_config_id", sa.Integer(), nullable=False),
        sa.Column("standard_answer_category_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["slack_bot_config_id"],
            ["slack_bot_config.id"],
        ),
        sa.ForeignKeyConstraint(
            ["standard_answer_category_id"],
            ["standard_answer_category.id"],
        ),
        sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"),
    )

    op.add_column(
        "chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("chat_session", "slack_thread_id")

    op.drop_table("slack_bot_config__standard_answer_category")
    op.drop_table("standard_answer__standard_answer_category")
    op.drop_table("standard_answer_category")
    op.drop_table("standard_answer")
@@ -0,0 +1,45 @@
"""combined slack id fields

Revision ID: d716b0791ddd
Revises: 7aea705850d5
Create Date: 2024-07-10 17:57:45.630550

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE slack_bot_config
        SET channel_config = jsonb_set(
            channel_config,
            '{respond_member_group_list}',
            coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) ||
            coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb)
        ) - 'respond_team_member_list' - 'respond_slack_group_list'
        """
    )


def downgrade() -> None:
    op.execute(
        """
        UPDATE slack_bot_config
        SET channel_config = jsonb_set(
            jsonb_set(
                channel_config - 'respond_member_group_list',
                '{respond_team_member_list}',
                '[]'::jsonb
            ),
            '{respond_slack_group_list}',
            '[]'::jsonb
        )
        """
    )
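The upgrade above folds the two old per-channel Slack ID lists into a single respond_member_group_list key inside the channel_config JSON. A minimal sketch of the equivalent transformation in Python; the field values are hypothetical, only the key names come from the migration:

```python
# Sketch of what the upgrade SQL does to one channel_config blob (illustrative only).
def combine_slack_id_fields(channel_config: dict) -> dict:
    combined = channel_config.pop("respond_team_member_list", []) + channel_config.pop(
        "respond_slack_group_list", []
    )
    channel_config["respond_member_group_list"] = combined
    return channel_config


before = {
    "channel_names": ["#support"],          # hypothetical value
    "respond_team_member_list": ["user_a@example.com"],
    "respond_slack_group_list": ["oncall-group"],
}
after = combine_slack_id_fields(before)
# {'channel_names': ['#support'],
#  'respond_member_group_list': ['user_a@example.com', 'oncall-group']}
```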
@@ -98,7 +98,6 @@ def _run_indexing(
    3. Updates Postgres to record the indexed documents + the outcome of this run
    """
    start_time = time.time()

    db_embedding_model = index_attempt.embedding_model
    index_name = db_embedding_model.index_name

@@ -116,6 +115,8 @@ def _run_indexing(
        normalize=db_embedding_model.normalize,
        query_prefix=db_embedding_model.query_prefix,
        passage_prefix=db_embedding_model.passage_prefix,
        api_key=db_embedding_model.api_key,
        provider_type=db_embedding_model.provider_type,
    )

    indexing_pipeline = build_indexing_pipeline(

@@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
        db_session=db_session,
        index_attempt_id=index_attempt_id,
    )

    if attempt is None:
        raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")


@@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:

    # So that the first time users aren't surprised by really slow speed of first
    # batch of documents indexed
    logger.info("Running a first inference to warm up embedding model")
    warm_up_encoders(
        model_name=db_embedding_model.model_name,
        normalize=db_embedding_model.normalize,
        model_server_host=INDEXING_MODEL_SERVER_HOST,
        model_server_port=MODEL_SERVER_PORT,
    )
    if db_embedding_model.cloud_provider_id is None:
        logger.info("Running a first inference to warm up embedding model")
        warm_up_encoders(
            model_name=db_embedding_model.model_name,
            normalize=db_embedding_model.normalize,
            model_server_host=INDEXING_MODEL_SERVER_HOST,
            model_server_port=MODEL_SERVER_PORT,
        )

    client_primary: Client | SimpleJobClient
    client_secondary: Client | SimpleJobClient
@@ -1,5 +1,4 @@
import re
from collections.abc import Sequence
from typing import cast

from sqlalchemy.orm import Session

@@ -9,42 +8,30 @@ from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger

logger = setup_logger()


def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
    return LlmDoc(
        document_id=inf_chunk.document_id,
        document_id=inference_section.center_chunk.document_id,
        # This one is using the combined content of all the chunks of the section
        # In default settings, this is the same as just the content of base chunk
        content=inf_chunk.combined_content,
        blurb=inf_chunk.blurb,
        semantic_identifier=inf_chunk.semantic_identifier,
        source_type=inf_chunk.source_type,
        metadata=inf_chunk.metadata,
        updated_at=inf_chunk.updated_at,
        link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
        source_links=inf_chunk.source_links,
        content=inference_section.combined_content,
        blurb=inference_section.center_chunk.blurb,
        semantic_identifier=inference_section.center_chunk.semantic_identifier,
        source_type=inference_section.center_chunk.source_type,
        metadata=inference_section.center_chunk.metadata,
        updated_at=inference_section.center_chunk.updated_at,
        link=inference_section.center_chunk.source_links[0]
        if inference_section.center_chunk.source_links
        else None,
        source_links=inference_section.center_chunk.source_links,
    )


def map_document_id_order(
    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
    order_mapping = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
        if chunk.document_id not in order_mapping:
            order_mapping[chunk.document_id] = current
            current += 1

    return order_mapping
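map_document_id_order, shown unchanged above, assigns a stable citation number to each distinct document_id in first-seen order. A small illustrative call, assuming the function above is in scope; real callers pass InferenceChunk or LlmDoc objects, but only the document_id attribute is read:

```python
from dataclasses import dataclass


@dataclass
class FakeChunk:  # stand-in for InferenceChunk / LlmDoc in this sketch
    document_id: str


chunks = [FakeChunk("doc-a"), FakeChunk("doc-b"), FakeChunk("doc-a"), FakeChunk("doc-c")]
print(map_document_id_order(chunks))
# {'doc-a': 1, 'doc-b': 2, 'doc-c': 3}
print(map_document_id_order(chunks, one_indexed=False))
# {'doc-a': 0, 'doc-b': 1, 'doc-c': 2}
```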
def create_chat_chain(
    chat_session_id: int,
    db_session: Session,
@@ -1,5 +1,3 @@
from typing import cast

import yaml
from sqlalchemy.orm import Session

@@ -50,7 +48,7 @@ def load_personas_from_yaml(
    with Session(get_sqlalchemy_engine()) as db_session:
        for persona in all_personas:
            doc_set_names = persona["document_sets"]
            doc_sets: list[DocumentSetDBModel] | None = [
            doc_sets: list[DocumentSetDBModel] = [
                get_or_create_document_set_by_name(db_session, name)
                for name in doc_set_names
            ]

@@ -58,22 +56,24 @@ def load_personas_from_yaml(
            # Assume if user hasn't set any document sets for the persona, the user may want
            # to later attach document sets to the persona manually, therefore, don't overwrite/reset
            # the document sets for the persona
            if not doc_sets:
                doc_sets = None

            prompt_set_names = persona["prompts"]
            if not prompt_set_names:
                prompts: list[PromptDBModel | None] | None = None
            doc_set_ids: list[int] | None = None
            if doc_sets:
                doc_set_ids = [doc_set.id for doc_set in doc_sets]
            else:
                prompts = [
                doc_set_ids = None

            prompt_ids: list[int] | None = None
            prompt_set_names = persona["prompts"]
            if prompt_set_names:
                prompts: list[PromptDBModel | None] = [
                    get_prompt_by_name(prompt_name, user=None, db_session=db_session)
                    for prompt_name in prompt_set_names
                ]
                if any([prompt is None for prompt in prompts]):
                    raise ValueError("Invalid Persona configs, not all prompts exist")

                if not prompts:
                    prompts = None
                if prompts:
                    prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

            p_id = persona.get("id")
            upsert_persona(

@@ -91,8 +91,8 @@ def load_personas_from_yaml(
                llm_model_provider_override=None,
                llm_model_version_override=None,
                recency_bias=RecencyBiasSetting(persona["recency_bias"]),
                prompts=cast(list[PromptDBModel] | None, prompts),
                document_sets=doc_sets,
                prompt_ids=prompt_ids,
                document_set_ids=doc_set_ids,
                default_persona=True,
                is_public=True,
                db_session=db_session,
@@ -42,11 +42,21 @@ class QADocsResponse(RetrievalDocs):
        return initial_dict


# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
    relevant_chunk_indices: list[int]


class RelevanceChunk(BaseModel):
    # TODO make this document level. Also slight misnomer here as this is actually
    # done at the section level currently rather than the chunk
    relevant: bool | None = None
    content: str | None = None


class LLMRelevanceSummaryResponse(BaseModel):
    relevance_summaries: dict[str, RelevanceChunk]


class DanswerAnswerPiece(BaseModel):
    # A small piece of a complete answer. Used for streaming back answers.
    answer_piece: str | None  # if None, specifies the end of an Answer
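The new RelevanceChunk and LLMRelevanceSummaryResponse models carry per-section relevance verdicts keyed by an identifier string. A small usage sketch, assuming these live in danswer.chat.models alongside the other models in this hunk; the key and content values are made up for illustration:

```python
from danswer.chat.models import LLMRelevanceSummaryResponse, RelevanceChunk

chunk_verdict = RelevanceChunk(
    relevant=True,
    content="Payroll runs on the last business day of each month.",  # hypothetical
)
summary = LLMRelevanceSummaryResponse(
    relevance_summaries={"hr-handbook__0": chunk_verdict}  # hypothetical key format
)
print(summary.relevance_summaries["hr-handbook__0"].relevant)  # True
```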
@@ -10,14 +10,15 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message

@@ -49,9 +50,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices

@@ -66,6 +71,14 @@ from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
    INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
    internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@@ -143,6 +156,37 @@ def _handle_search_tool_response_summary(
    )


def _handle_internet_search_tool_response_summary(
    packet: ToolResponse,
    db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
    internet_search_response = cast(InternetSearchResponse, packet.response)
    server_search_docs = internet_search_response_to_search_docs(
        internet_search_response
    )

    reference_db_search_docs = [
        create_db_search_doc(server_search_doc=doc, db_session=db_session)
        for doc in server_search_docs
    ]
    response_docs = [
        translate_db_search_doc_to_server_search_doc(db_search_doc)
        for db_search_doc in reference_db_search_docs
    ]
    return (
        QADocsResponse(
            rephrased_query=internet_search_response.revised_query,
            top_documents=response_docs,
            predicted_flow=QueryFlow.QUESTION_ANSWER,
            predicted_search=SearchType.HYBRID,
            applied_source_filters=[],
            applied_time_cutoff=None,
            recency_bias_multiplier=1.0,
        ),
        reference_db_search_docs,
    )


def _check_should_force_search(
    new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:

@@ -170,7 +214,7 @@ def _check_should_force_search(
        args = {"query": new_msg_req.message}

        return ForceUseTool(
            tool_name=SearchTool.NAME,
            tool_name=SearchTool._NAME,
            args=args,
        )
    return None
@@ -338,7 +382,7 @@ def stream_chat_message_objects(
        )

        selected_db_search_docs = None
        selected_llm_docs: list[LlmDoc] | None = None
        selected_sections: list[InferenceSection] | None = None
        if reference_doc_ids:
            identifier_tuples = get_doc_query_identifiers_from_model(
                search_doc_ids=reference_doc_ids,

@@ -348,8 +392,8 @@ def stream_chat_message_objects(
            )

            # Generates full documents currently
            # May extend to include chunk ranges
            selected_llm_docs = inference_documents_from_ids(
            # May extend to use sections instead in the future
            selected_sections = inference_sections_from_ids(
                doc_identifiers=identifier_tuples,
                document_index=document_index,
            )

@@ -428,20 +472,20 @@ def stream_chat_message_objects(
                    llm=llm,
                    fast_llm=fast_llm,
                    pruning_config=document_pruning_config,
                    selected_docs=selected_llm_docs,
                    selected_sections=selected_sections,
                    chunks_above=new_msg_req.chunks_above,
                    chunks_below=new_msg_req.chunks_below,
                    full_doc=new_msg_req.full_doc,
                )
                tool_dict[db_tool_model.id] = [search_tool]
            elif tool_cls.__name__ == ImageGenerationTool.__name__:
                dalle_key = None
                img_generation_llm_config: LLMConfig | None = None
                if (
                    llm
                    and llm.config.api_key
                    and llm.config.model_provider == "openai"
                ):
                    dalle_key = llm.config.api_key
                    img_generation_llm_config = llm.config
                else:
                    llm_providers = fetch_existing_llm_providers(db_session)
                    openai_provider = next(
@@ -458,13 +502,31 @@ def stream_chat_message_objects(
                        raise ValueError(
                            "Image generation tool requires an OpenAI API key"
                        )
                    dalle_key = openai_provider.api_key
                    img_generation_llm_config = LLMConfig(
                        model_provider=openai_provider.provider,
                        model_name=openai_provider.default_model_name,
                        temperature=GEN_AI_TEMPERATURE,
                        api_key=openai_provider.api_key,
                        api_base=openai_provider.api_base,
                        api_version=openai_provider.api_version,
                    )
                tool_dict[db_tool_model.id] = [
                    ImageGenerationTool(
                        api_key=dalle_key,
                        api_key=cast(str, img_generation_llm_config.api_key),
                        api_base=img_generation_llm_config.api_base,
                        api_version=img_generation_llm_config.api_version,
                        additional_headers=litellm_additional_headers,
                    )
                ]
            elif tool_cls.__name__ == InternetSearchTool.__name__:
                bing_api_key = BING_API_KEY
                if not bing_api_key:
                    raise ValueError(
                        "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
                    )
                tool_dict[db_tool_model.id] = [
                    InternetSearchTool(api_key=bing_api_key)
                ]

            continue

@@ -571,6 +633,15 @@ def stream_chat_message_objects(
                    yield ImageGenerationDisplay(
                        file_ids=[str(file_id) for file_id in file_ids]
                    )
                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
                    (
                        qa_docs_response,
                        reference_db_search_docs,
                    ) = _handle_internet_search_tool_response_summary(
                        packet=packet,
                        db_session=db_session,
                    )
                    yield qa_docs_response
                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
                    yield CustomToolResponse(

@@ -612,7 +683,7 @@ def stream_chat_message_objects(
        tool_name_to_tool_id: dict[str, int] = {}
        for tool_id, tool_list in tool_dict.items():
            for tool in tool_list:
                tool_name_to_tool_id[tool.name()] = tool_id
                tool_name_to_tool_id[tool.name] = tool_id

        gen_ai_response_message = partial_response(
            message=answer.llm_answer,
@@ -223,6 +223,11 @@ MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
    os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)

# comma delimited list of zendesk article labels to skip indexing for
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
    "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")


#####
# Indexing Configs

@@ -243,13 +248,15 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))

@@ -266,10 +273,14 @@ JOB_TIMEOUT = 60 * 60 * 6  # 6 hours default
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
    os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
# Sets LiteLLM to verbose logging
LOG_ALL_MODEL_INTERACTIONS = (
    os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# Logs Danswer only model interactions like prompts, responses, messages etc.
LOG_DANSWER_MODEL_INTERACTIONS = (
    os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -5,7 +5,10 @@ PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"

NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20

# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)

@@ -25,9 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
DISABLE_LLM_FILTER_EXTRACTION = (
    os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (

@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords

@@ -73,8 +75,22 @@ LANGUAGE_CHAT_NAMING_HINT = (
    or "The name of the conversation must be in the same language as the user query."
)


# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
    os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"


# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None

# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False

# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
@@ -19,6 +19,7 @@ DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed

@@ -43,7 +44,8 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"

# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
INDEX_SEPARATOR = "==="

@@ -104,6 +106,7 @@ class DocumentSource(str, Enum):
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"
    NOT_APPLICABLE = "not_applicable"


class BlobType(str, Enum):

@@ -112,6 +115,9 @@ class BlobType(str, Enum):
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"

    # Special case, for internet search
    NOT_APPLICABLE = "not_applicable"


class DocumentIndexType(str, Enum):
    COMBINED = "combined"  # Vespa
@@ -47,10 +47,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
    os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Auto detect query options like time cutoff or heavily favor recently updated docs
DISABLE_DANSWER_BOT_FILTER_DETECT = (
    os.environ.get("DISABLE_DANSWER_BOT_FILTER_DETECT", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
@@ -39,8 +39,8 @@ ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0

# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
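The range constants above exist so raw cross-encoder scores can be mapped onto a known scale for display. A minimal sketch of how such constants are typically used to clamp and rescale a score into [0, 1]; the helper itself is illustrative and not part of the diff:

```python
# Illustrative helper, not from the diff: rescale a raw reranker score for display.
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0


def normalize_score(raw_score: float) -> float:
    span = CROSS_ENCODER_RANGE_MAX - CROSS_ENCODER_RANGE_MIN
    clamped = min(max(raw_score, CROSS_ENCODER_RANGE_MIN), CROSS_ENCODER_RANGE_MAX)
    return (clamped - CROSS_ENCODER_RANGE_MIN) / span


print(normalize_score(0.73))  # 0.73 under the new [0, 1] range
```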
@@ -37,16 +37,18 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()

# Potential Improvements
# 1. If wiki page instead of space, do a search of all the children of the page instead of index all in the space
# 2. Include attachments, etc
# 3. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost


def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
    """Sample
    https://danswer.atlassian.net/wiki/spaces/1234abcd/overview
    URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
    URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview

    wiki_base is https://danswer.atlassian.net/wiki
    space is 1234abcd
    page_id is 5678efgh
    """
    parsed_url = urlparse(wiki_url)
    wiki_base = (

@@ -55,18 +57,25 @@ def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str]:
        + parsed_url.netloc
        + parsed_url.path.split("/spaces")[0]
    )
    space = parsed_url.path.split("/")[3]
    return wiki_base, space

    path_parts = parsed_url.path.split("/")
    space = path_parts[3]

    page_id = path_parts[5] if len(path_parts) > 5 else ""
    return wiki_base, space, page_id


def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]:
    """Sample
    https://danswer.ai/confluence/display/1234abcd/overview
    URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview
    URL w/o page https://danswer.ai/confluence/display/1234abcd/overview
    wiki_base is https://danswer.ai/confluence
    space is 1234abcd
    page_id is 5678efgh
    """
    # /display/ is always right before the space and at the end of the base url
    # /display/ is always right before the space and at the end of the base print()
    DISPLAY = "/display/"
    PAGE = "/pages/"

    parsed_url = urlparse(wiki_url)
    wiki_base = (

@@ -76,10 +85,13 @@ def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str]:
        + parsed_url.path.split(DISPLAY)[0]
    )
    space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
    return wiki_base, space
    page_id = ""
    if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
        page_id = content[1]
    return wiki_base, space, page_id


def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
    is_confluence_cloud = (
        ".atlassian.net/wiki/spaces/" in wiki_url
        or ".jira.com/wiki/spaces/" in wiki_url

@@ -87,15 +99,19 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:

    try:
        if is_confluence_cloud:
            wiki_base, space = _extract_confluence_keys_from_cloud_url(wiki_url)
            wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(
                wiki_url
            )
        else:
            wiki_base, space = _extract_confluence_keys_from_datacenter_url(wiki_url)
            wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
                wiki_url
            )
    except Exception as e:
        error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base and space names. Exception: {e}"
        error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. Exception: {e}"
        logger.error(error_msg)
        raise ValueError(error_msg)

    return wiki_base, space, is_confluence_cloud
    return wiki_base, space, page_id, is_confluence_cloud
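With the page-aware parsing above, the extractor now returns a 4-tuple. An illustrative call, assuming the function above is in scope; the URLs follow the samples given in the docstrings and the expected output is shown as comments:

```python
wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
    "https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview"
)
print(wiki_base, space, page_id, is_cloud)
# https://danswer.atlassian.net/wiki 1234abcd 5678efgh True

# A space-level URL yields an empty page_id, which later triggers a space-level scan.
_, _, page_id, _ = extract_confluence_keys_from_url(
    "https://danswer.atlassian.net/wiki/spaces/ASAM/overview"
)
print(repr(page_id))  # ''
```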
@lru_cache()

@@ -196,10 +212,135 @@ def _comment_dfs(
    return comments_str


class RecursiveIndexer:
    def __init__(
        self,
        batch_size: int,
        confluence_client: Confluence,
        index_origin: bool,
        origin_page_id: str,
    ) -> None:
        self.batch_size = 1
        # batch_size
        self.confluence_client = confluence_client
        self.index_origin = index_origin
        self.origin_page_id = origin_page_id
        self.pages = self.recurse_children_pages(0, self.origin_page_id)

    def get_pages(self, ind: int, size: int) -> list[dict]:
        if ind * size > len(self.pages):
            return []
        return self.pages[ind * size : (ind + 1) * size]

    def _fetch_origin_page(
        self,
    ) -> dict[str, Any]:
        get_page_by_id = make_confluence_call_handle_rate_limit(
            self.confluence_client.get_page_by_id
        )
        try:
            origin_page = get_page_by_id(
                self.origin_page_id, expand="body.storage.value,version"
            )
            return origin_page
        except Exception as e:
            logger.warning(
                f"Appending orgin page with id {self.origin_page_id} failed: {e}"
            )
            return {}

    def recurse_children_pages(
        self,
        start_ind: int,
        page_id: str,
    ) -> list[dict[str, Any]]:
        pages: list[dict[str, Any]] = []
        current_level_pages: list[dict[str, Any]] = []
        next_level_pages: list[dict[str, Any]] = []

        # Initial fetch of first level children
        index = start_ind
        while batch := self._fetch_single_depth_child_pages(
            index, self.batch_size, page_id
        ):
            current_level_pages.extend(batch)
            index += len(batch)

        pages.extend(current_level_pages)

        # Recursively index children and children's children, etc.
        while current_level_pages:
            for child in current_level_pages:
                child_index = 0
                while child_batch := self._fetch_single_depth_child_pages(
                    child_index, self.batch_size, child["id"]
                ):
                    next_level_pages.extend(child_batch)
                    child_index += len(child_batch)

            pages.extend(next_level_pages)
            current_level_pages = next_level_pages
            next_level_pages = []

        if self.index_origin:
            try:
                origin_page = self._fetch_origin_page()
                pages.append(origin_page)
            except Exception as e:
                logger.warning(f"Appending origin page with id {page_id} failed: {e}")

        return pages

    def _fetch_single_depth_child_pages(
        self, start_ind: int, batch_size: int, page_id: str
    ) -> list[dict[str, Any]]:
        child_pages: list[dict[str, Any]] = []

        get_page_child_by_type = make_confluence_call_handle_rate_limit(
            self.confluence_client.get_page_child_by_type
        )

        try:
            child_page = get_page_child_by_type(
                page_id,
                type="page",
                start=start_ind,
                limit=batch_size,
                expand="body.storage.value,version",
            )

            child_pages.extend(child_page)
            return child_pages

        except Exception:
            logger.warning(
                f"Batch failed with page {page_id} at offset {start_ind} "
                f"with size {batch_size}, processing pages individually..."
            )

            for i in range(batch_size):
                ind = start_ind + i
                try:
                    child_page = get_page_child_by_type(
                        page_id,
                        type="page",
                        start=ind,
                        limit=1,
                        expand="body.storage.value,version",
                    )
                    child_pages.extend(child_page)
                except Exception as e:
                    logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
                    raise e

            return child_pages
class ConfluenceConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        wiki_page_url: str,
        index_origin: bool = True,
        batch_size: int = INDEX_BATCH_SIZE,
        continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
        # if a page has one of the labels specified in this list, we will just

@@ -210,11 +351,27 @@ class ConfluenceConnector(LoadConnector, PollConnector):
        self.batch_size = batch_size
        self.continue_on_failure = continue_on_failure
        self.labels_to_skip = set(labels_to_skip)
        self.wiki_base, self.space, self.is_cloud = extract_confluence_keys_from_url(
            wiki_page_url
        )
        self.recursive_indexer: RecursiveIndexer | None = None
        self.index_origin = index_origin
        (
            self.wiki_base,
            self.space,
            self.page_id,
            self.is_cloud,
        ) = extract_confluence_keys_from_url(wiki_page_url)

        self.space_level_scan = False

        self.confluence_client: Confluence | None = None

        if self.page_id is None or self.page_id == "":
            self.space_level_scan = True

        logger.info(
            f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
            + f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        username = credentials["confluence_username"]
        access_token = credentials["confluence_access_token"]

@@ -232,8 +389,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
        self,
        confluence_client: Confluence,
        start_ind: int,
    ) -> Collection[dict[str, Any]]:
        def _fetch(start_ind: int, batch_size: int) -> Collection[dict[str, Any]]:
    ) -> list[dict[str, Any]]:
        def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
            get_all_pages_from_space = make_confluence_call_handle_rate_limit(
                confluence_client.get_all_pages_from_space
            )

@@ -242,9 +399,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                    self.space,
                    start=start_ind,
                    limit=batch_size,
                    status="current"
                    if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
                    else None,
                    status=(
                        "current"
                        if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
                        else None
                    ),
                    expand="body.storage.value,version",
                )
            except Exception:

@@ -263,9 +422,11 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                            self.space,
                            start=start_ind + i,
                            limit=1,
                            status="current"
                            if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
                            else None,
                            status=(
                                "current"
                                if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
                                else None
                            ),
                            expand="body.storage.value,version",
                        )
                    )

@@ -286,17 +447,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):

            return view_pages

        def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
            if self.recursive_indexer is None:
                self.recursive_indexer = RecursiveIndexer(
                    origin_page_id=self.page_id,
                    batch_size=self.batch_size,
                    confluence_client=self.confluence_client,
                    index_origin=self.index_origin,
                )

            return self.recursive_indexer.get_pages(start_ind, batch_size)

        pages: list[dict[str, Any]] = []

        try:
            return _fetch(start_ind, self.batch_size)
            pages = (
                _fetch_space(start_ind, self.batch_size)
                if self.space_level_scan
                else _fetch_page(start_ind, self.batch_size)
            )
            return pages

        except Exception as e:
            if not self.continue_on_failure:
                raise e

        # error checking phase, only reachable if `self.continue_on_failure=True`
        pages: list[dict[str, Any]] = []
        for i in range(self.batch_size):
            try:
                pages.extend(_fetch(start_ind + i, 1))
                pages = (
                    _fetch_space(start_ind, self.batch_size)
                    if self.space_level_scan
                    else _fetch_page(start_ind, self.batch_size)
                )
                return pages

            except Exception:
                logger.exception(
                    "Ran into exception when fetching pages from Confluence"

@@ -308,6 +493,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
        get_page_child_by_type = make_confluence_call_handle_rate_limit(
            confluence_client.get_page_child_by_type
        )

        try:
            comment_pages = cast(
                Collection[dict[str, Any]],

@@ -356,7 +542,14 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                page_id, start=0, limit=500
            )
            for attachment in attachments_container["results"]:
                if attachment["metadata"]["mediaType"] in ["image/jpeg", "image/png"]:
                if attachment["metadata"]["mediaType"] in [
                    "image/jpeg",
                    "image/png",
                    "image/gif",
                    "image/svg+xml",
                    "video/mp4",
                    "video/quicktime",
                ]:
                    continue

                if attachment["title"] not in files_in_used:

@@ -367,9 +560,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):

                if response.status_code == 200:
                    extract = extract_file_text(
                        attachment["title"],
                        io.BytesIO(response.content),
                        break_on_unprocessable=False,
                        attachment["title"], io.BytesIO(response.content), False
                    )
                    files_attachment_content.append(extract)

@@ -389,8 +580,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):

        if self.confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence")

        batch = self._fetch_pages(self.confluence_client, start_ind)

        for page in batch:
            last_modified_str = page["version"]["when"]
            author = cast(str | None, page["version"].get("by", {}).get("email"))

@@ -405,6 +596,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):

            if time_filter is None or time_filter(last_modified):
                page_id = page["id"]

                if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
                    page_labels = self._fetch_labels(self.confluence_client, page_id)

@@ -416,6 +608,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                            f"Page with ID '{page_id}' has a label which has been "
                            f"designated as disallowed: {label_intersection}. Skipping."
                        )

                        continue

                page_html = (

@@ -436,7 +629,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                    page_text += attachment_text
                comments_text = self._fetch_comments(self.confluence_client, page_id)
                page_text += comments_text

                doc_metadata: dict[str, str | list[str]] = {
                    "Wiki Space Name": self.space
                }

@@ -450,9 +642,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                        source=DocumentSource.CONFLUENCE,
                        semantic_identifier=page["title"],
                        doc_updated_at=last_modified,
                        primary_owners=[BasicExpertInfo(email=author)]
                        if author
                        else None,
                        primary_owners=(
                            [BasicExpertInfo(email=author)] if author else None
                        ),
                        metadata=doc_metadata,
                    )
                )
@@ -1,10 +1,14 @@
import time
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar

from requests import HTTPError
from retry import retry

from danswer.utils.logger import setup_logger

logger = setup_logger()


F = TypeVar("F", bound=Callable[..., Any])

@@ -18,23 +22,38 @@ class ConfluenceRateLimitError(Exception):


def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
    @retry(
        exceptions=ConfluenceRateLimitError,
        tries=10,
        delay=1,
        max_delay=600,  # 10 minutes
        backoff=2,
        jitter=1,
    )
    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        try:
            return confluence_call(*args, **kwargs)
        except HTTPError as e:
            if (
                e.response.status_code == 429
                or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
            ):
                raise ConfluenceRateLimitError()
            raise
        starting_delay = 5
        backoff = 2
        max_delay = 600

        for attempt in range(10):
            try:
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                if (
                    e.response.status_code == 429
                    or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
                ):
                    retry_after = None
                    try:
                        retry_after = int(e.response.headers.get("Retry-After"))
                    except (ValueError, TypeError):
                        pass

                    if retry_after:
                        logger.warning(
                            f"Rate limit hit. Retrying after {retry_after} seconds..."
                        )
                        time.sleep(retry_after)
                    else:
                        logger.warning(
                            "Rate limit hit. Retrying with exponential backoff..."
                        )
                        delay = min(starting_delay * (backoff**attempt), max_delay)
                        time.sleep(delay)
                else:
                    # re-raise, let caller handle
                    raise

    return cast(F, wrapped_call)
@@ -6,6 +6,7 @@ from typing import TypeVar

from dateutil.parser import parse

from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.text_processing import is_valid_email

@@ -57,3 +58,7 @@ def process_in_batches(
) -> Iterator[list[U]]:
    for i in range(0, len(objects), batch_size):
        yield [process_function(obj) for obj in objects[i : i + batch_size]]


def get_metadata_keys_to_ignore() -> list[str]:
    return [IGNORE_FOR_QA]
@@ -11,6 +11,9 @@ from requests import Response
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector

@@ -58,67 +61,36 @@ class DiscourseConnector(PollConnector):
        self.category_id_map: dict[int, str] = {}

        self.batch_size = batch_size

        self.permissions: DiscoursePerms | None = None
        self.active_categories: set | None = None

    @rate_limit_builder(max_calls=100, period=60)
    def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
        if not self.permissions:
            raise ConnectorMissingCredentialError("Discourse")
        return discourse_request(endpoint, self.permissions, params)

    def _get_categories_map(
        self,
    ) -> None:
        assert self.permissions is not None
        categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json")
        response = discourse_request(
        response = self._make_request(
            endpoint=categories_endpoint,
            perms=self.permissions,
            params={"include_subcategories": True},
        )
        categories = response.json()["category_list"]["categories"]

        self.category_id_map = {
            category["id"]: category["name"]
            for category in categories
            if not self.categories or category["name"].lower() in self.categories
            cat["id"]: cat["name"]
            for cat in categories
            if not self.categories or cat["name"].lower() in self.categories
        }

    def _get_latest_topics(
        self, start: datetime | None, end: datetime | None
    ) -> list[int]:
        assert self.permissions is not None
        topic_ids = []

        valid_categories = set(self.category_id_map.keys())

        latest_endpoint = urllib.parse.urljoin(self.base_url, "latest.json")
        response = discourse_request(endpoint=latest_endpoint, perms=self.permissions)
        topics = response.json()["topic_list"]["topics"]
        for topic in topics:
            last_time = topic.get("last_posted_at")
            if not last_time:
                continue
            last_time_dt = time_str_to_utc(last_time)

            if start and start > last_time_dt:
                continue
            if end and end < last_time_dt:
                continue

            if (
                self.categories
                and valid_categories
                and topic.get("category_id") not in valid_categories
            ):
                continue

            topic_ids.append(topic["id"])

        return topic_ids
        self.active_categories = set(self.category_id_map)

    def _get_doc_from_topic(self, topic_id: int) -> Document:
        assert self.permissions is not None
        topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json")
        response = discourse_request(
            endpoint=topic_endpoint,
            perms=self.permissions,
        )
        response = self._make_request(endpoint=topic_endpoint)
        topic = response.json()

        topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}")

@@ -167,26 +139,78 @@ class DiscourseConnector(PollConnector):
        )
        return doc

    def _get_latest_topics(
        self, start: datetime | None, end: datetime | None, page: int
    ) -> list[int]:
        assert self.permissions is not None
        topic_ids = []

        if not self.categories:
            latest_endpoint = urllib.parse.urljoin(
                self.base_url, f"latest.json?page={page}"
            )
            response = self._make_request(endpoint=latest_endpoint)
            topics = response.json()["topic_list"]["topics"]

        else:
            topics = []
            empty_categories = []

            for category_id in self.category_id_map.keys():
                category_endpoint = urllib.parse.urljoin(
                    self.base_url, f"c/{category_id}.json?page={page}&sys=latest"
                )
                response = self._make_request(endpoint=category_endpoint)
                new_topics = response.json()["topic_list"]["topics"]

                if len(new_topics) == 0:
                    empty_categories.append(category_id)
                topics.extend(new_topics)

            for empty_category in empty_categories:
                self.category_id_map.pop(empty_category)

        for topic in topics:
            last_time = topic.get("last_posted_at")
            if not last_time:
                continue

            last_time_dt = time_str_to_utc(last_time)
            if (start and start > last_time_dt) or (end and end < last_time_dt):
                continue

            topic_ids.append(topic["id"])
            if len(topic_ids) >= self.batch_size:
                break

        return topic_ids

    def _yield_discourse_documents(
        self, topic_ids: list[int]
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        doc_batch: list[Document] = []
        for topic_id in topic_ids:
            doc_batch.append(self._get_doc_from_topic(topic_id))
        page = 1
        while topic_ids := self._get_latest_topics(start, end, page):
            doc_batch: list[Document] = []
            for topic_id in topic_ids:
                doc_batch.append(self._get_doc_from_topic(topic_id))
                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
                    doc_batch = []

            if len(doc_batch) >= self.batch_size:
            if doc_batch:
                yield doc_batch
                doc_batch = []
            page += 1

        if doc_batch:
            yield doc_batch

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
    def load_credentials(
        self,
        credentials: dict[str, Any],
    ) -> dict[str, Any] | None:
        self.permissions = DiscoursePerms(
            api_key=credentials["discourse_api_key"],
            api_username=credentials["discourse_api_username"],
        )

        return None

    def poll_source(

@@ -194,16 +218,13 @@ class DiscourseConnector(PollConnector):
    ) -> GenerateDocumentsOutput:
        if self.permissions is None:
            raise ConnectorMissingCredentialError("Discourse")

        start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
        end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)

        self._get_categories_map()

        latest_topic_ids = self._get_latest_topics(
            start=start_datetime, end=end_datetime
        )

        yield from self._yield_discourse_documents(latest_topic_ids)
        yield from self._yield_discourse_documents(start_datetime, end_datetime)


if __name__ == "__main__":

@@ -219,7 +240,5 @@ if __name__ == "__main__":

    current = time.time()
    one_year_ago = current - 24 * 60 * 60 * 360

    latest_docs = connector.poll_source(one_year_ago, current)

    print(next(latest_docs))
@@ -85,6 +85,11 @@ def _process_file(

    all_metadata = {**metadata, **file_metadata} if metadata else file_metadata

    # add a prefix to avoid conflicts with other connectors
    doc_id = f"FILE_CONNECTOR__{file_name}"
    if metadata:
        doc_id = metadata.get("document_id") or doc_id

    # If this is set, we will show this in the UI as the "name" of the file
    file_display_name = all_metadata.get("file_display_name") or os.path.basename(
        file_name

@@ -106,6 +111,7 @@ def _process_file(
        for k, v in all_metadata.items()
        if k
        not in [
            "document_id",
            "time_updated",
            "doc_updated_at",
            "link",

@@ -132,7 +138,7 @@ def _process_file(

    return [
        Document(
            id=f"FILE_CONNECTOR__{file_name}",  # add a prefix to avoid conflicts with other connectors
            id=doc_id,
            sections=[
                Section(link=all_metadata.get("link"), text=file_content_raw.strip())
            ],
@@ -6,6 +6,7 @@ from pydantic import BaseModel

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.utils.text_processing import make_url_compatible


@@ -117,7 +118,12 @@ class DocumentBase(BaseModel):
        # If title is explicitly empty, return a None here for embedding purposes
        if self.title == "":
            return None
        return self.semantic_identifier if self.title is None else self.title
        replace_chars = set(RETURN_SEPARATOR)
        title = self.semantic_identifier if self.title is None else self.title
        for char in replace_chars:
            title = title.replace(char, " ")
        title = title.strip()
        return title

    def get_metadata_str_attributes(self) -> list[str] | None:
        if not self.metadata:
@@ -368,7 +368,7 @@ class NotionConnector(LoadConnector, PollConnector):
            compare_time = time.mktime(
                time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
            )
            if compare_time <= end or compare_time > start:
            if compare_time > start and compare_time <= end:
                filtered_pages += [NotionPage(**page)]
        return filtered_pages
@@ -79,8 +79,9 @@ class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        extracted_id = f"{ID_PREFIX}{object_dict['Id']}"
        extracted_link = f"https://{self.sf_client.sf_instance}/{extracted_id}"
        salesforce_id = object_dict["Id"]
        danswer_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
        extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
        extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
        extracted_object_text = extract_dict_text(object_dict)
        extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")

@@ -91,7 +92,7 @@ class SalesforceConnector(LoadConnector, PollConnector, IdConnector):
        ]

        doc = Document(
            id=extracted_id,
            id=danswer_salesforce_id,
            sections=[Section(link=extracted_link, text=extracted_object_text)],
            source=DocumentSource.SALESFORCE,
            semantic_identifier=extracted_semantic_identifier,
@@ -29,6 +29,7 @@ from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import pdf_to_text
from danswer.file_processing.html_utils import web_html_cleanup
from danswer.utils.logger import setup_logger
from danswer.utils.sitemap import list_pages_for_site

logger = setup_logger()

@@ -145,16 +146,21 @@ def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
    response.raise_for_status()

    soup = BeautifulSoup(response.content, "html.parser")
    result = [
    urls = [
        _ensure_absolute_url(sitemap_url, loc_tag.text)
        for loc_tag in soup.find_all("loc")
    ]
    if not result:

    if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
        # the given url doesn't look like a sitemap, let's try to find one
        urls = list_pages_for_site(sitemap_url)

    if len(urls) == 0:
        raise ValueError(
            f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
        )

    return result
    return urls


def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:

@@ -264,7 +270,7 @@ class WebConnector(LoadConnector):
                        id=current_url,
                        sections=[Section(link=current_url, text=page_text)],
                        source=DocumentSource.WEB,
                        semantic_identifier=current_url.split(".")[-1],
                        semantic_identifier=current_url.split("/")[-1],
                        metadata={},
                    )
                )
@@ -4,6 +4,7 @@ from zenpy import Zenpy  # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article  # type: ignore

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
    time_str_to_utc,

@@ -81,7 +82,14 @@ class ZendeskConnector(LoadConnector, PollConnector):
            )
            doc_batch = []
            for article in articles:
                if article.body is None or article.draft:
                if (
                    article.body is None
                    or article.draft
                    or any(
                        label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
                        for label in article.label_names
                    )
                ):
                    continue

                doc_batch.append(_article_to_document(article))
@@ -25,6 +25,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.icons import source_to_github_img_link

@@ -353,6 +354,22 @@ def build_quotes_block(
    return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]


def build_standard_answer_blocks(
    answer_message: str,
) -> list[Block]:
    generate_button_block = ButtonElement(
        action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
        text="Generate Full Answer",
    )
    answer_block = SectionBlock(text=answer_message)
    return [
        answer_block,
        ActionsBlock(
            elements=[generate_button_block],
        ),
    ]


def build_qa_response_blocks(
    message_id: int | None,
    answer: str | None,

@@ -457,7 +474,7 @@ def build_follow_up_resolved_blocks(
    if tag_str:
        tag_str += " "

    group_str = " ".join([f"<!subteam^{group}>" for group in group_ids])
    group_str = " ".join([f"<!subteam^{group_id}|>" for group_id in group_ids])
    if group_str:
        group_str += " "
@@ -8,6 +8,7 @@ FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"


class FeedbackVisibility(str, Enum):
@@ -1,3 +1,4 @@
import logging
from typing import Any
from typing import cast

@@ -8,6 +9,7 @@ from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session

from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import make_slack_api_rate_limited

@@ -21,12 +23,17 @@ from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_message import (
    remove_scheduled_feedback_reminder,
)
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
    handle_regular_answer,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_groupids_from_names
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine

@@ -36,7 +43,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger

logger_base = setup_logger()
logger = setup_logger()


def handle_doc_feedback_button(

@@ -44,7 +51,7 @@
    client: SocketModeClient,
) -> None:
    if not (actions := req.payload.get("actions")):
        logger_base.error("Missing actions. Unable to build the source feedback view")
        logger.error("Missing actions. Unable to build the source feedback view")
        return

    # Extracts the feedback_id coming from the 'source feedback' button

@@ -72,6 +79,66 @@ def handle_doc_feedback_button(
    )


def handle_generate_answer_button(
    req: SocketModeRequest,
    client: SocketModeClient,
) -> None:
    channel_id = req.payload["channel"]["id"]
    channel_name = req.payload["channel"]["name"]
    message_ts = req.payload["message"]["ts"]
    thread_ts = req.payload["container"]["thread_ts"]
    user_id = req.payload["user"]["id"]

    if not thread_ts:
        raise ValueError("Missing thread_ts in the payload")

    thread_messages = read_slack_thread(
        channel=channel_id, thread=thread_ts, client=client.web_client
    )
    # remove all assistant messages till we get to the last user message
    # we want the new answer to be generated off of the last "question" in
    # the thread
    for i in range(len(thread_messages) - 1, -1, -1):
        if thread_messages[i].role == MessageType.USER:
            break
        if thread_messages[i].role == MessageType.ASSISTANT:
            thread_messages.pop(i)

    # tell the user that we're working on it
    # Send an ephemeral message to the user that we're generating the answer
    respond_in_thread(
        client=client.web_client,
        channel=channel_id,
        receiver_ids=[user_id],
        text="I'm working on generating a full answer for you. This may take a moment...",
        thread_ts=thread_ts,
    )

    with Session(get_sqlalchemy_engine()) as db_session:
        slack_bot_config = get_slack_bot_config_for_channel(
            channel_name=channel_name, db_session=db_session
        )

        handle_regular_answer(
            message_info=SlackMessageInfo(
                thread_messages=thread_messages,
                channel_to_respond=channel_id,
                msg_to_respond=cast(str, message_ts or thread_ts),
                thread_to_respond=cast(str, thread_ts or message_ts),
                sender=user_id or None,
                bypass_filters=True,
                is_bot_msg=False,
                is_bot_dm=False,
            ),
            slack_bot_config=slack_bot_config,
            receiver_ids=None,
            client=client.web_client,
            channel=channel_id,
            logger=cast(logging.Logger, logger),
            feedback_reminder_id=None,
        )


def handle_slack_feedback(
    feedback_id: str,
    feedback_type: str,

@@ -129,7 +196,7 @@ def handle_slack_feedback(
                feedback=feedback,
            )
    else:
        logger_base.error(f"Feedback type '{feedback_type}' not supported")
        logger.error(f"Feedback type '{feedback_type}' not supported")

    if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [
        LIKE_BLOCK_ACTION_ID,

@@ -193,11 +260,11 @@ def handle_followup_button(
        tag_names = slack_bot_config.channel_config.get("follow_up_tags")
        remaining = None
        if tag_names:
            tag_ids, remaining = fetch_userids_from_emails(
            tag_ids, remaining = fetch_user_ids_from_emails(
                tag_names, client.web_client
            )
            if remaining:
                group_ids, _ = fetch_groupids_from_names(remaining, client.web_client)
                group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client)

    blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)

@@ -272,7 +339,7 @@ def handle_followup_resolved_button(
    )

    if not response.get("ok"):
        logger_base.error("Unable to delete message for resolved")
        logger.error("Unable to delete message for resolved")

    if immediate:
        msg_text = f"{clicker_name} has marked this question as resolved!"
@@ -1,92 +1,34 @@
import datetime
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar

from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
    handle_regular_answer,
)
from danswer.danswerbot.slack.handlers.handle_standard_answers import (
    handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import fetch_userids_from_groups
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW

logger_base = setup_logger()

srl = SlackRateLimiter()

RT = TypeVar("RT")  # return type


def rate_limits(
    client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
    def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> RT:
            if not srl.is_available():
                func_randid, position = srl.init_waiter()
                srl.notify(client, channel, position, thread_ts)
                while not srl.is_available():
                    srl.waiter(func_randid)
            srl.acquire_slot()
            return func(*args, **kwargs)

        return wrapper

    return decorator


def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
    if details.is_bot_msg and details.sender:

@@ -174,17 +116,9 @@ def remove_scheduled_feedback_reminder(

def handle_message(
    message_info: SlackMessageInfo,
    channel_config: SlackBotConfig | None,
    slack_bot_config: SlackBotConfig | None,
    client: WebClient,
    feedback_reminder_id: str | None,
    num_retries: int = DANSWER_BOT_NUM_RETRIES,
    answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
    should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
    disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
    disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
    reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
    disable_cot: bool = DANSWER_BOT_DISABLE_COT,
    thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
    """Potentially respond to the user message depending on filters and if an answer was generated

@@ -201,14 +135,22 @@ def handle_message(
    )

    messages = message_info.thread_messages
    message_ts_to_respond_to = message_info.msg_to_respond
    sender_id = message_info.sender
    bypass_filters = message_info.bypass_filters
    is_bot_msg = message_info.is_bot_msg
    is_bot_dm = message_info.is_bot_dm

    action = "slack_message"
    if is_bot_msg:
        action = "slack_slash_message"
    elif bypass_filters:
        action = "slack_tag_message"
    elif is_bot_dm:
        action = "slack_dm_message"
    slack_usage_report(action=action, sender_id=sender_id, client=client)

    document_set_names: list[str] | None = None
    persona = channel_config.persona if channel_config else None
    persona = slack_bot_config.persona if slack_bot_config else None
    prompt = None
    if persona:
        document_set_names = [

@@ -216,36 +158,13 @@ def handle_message(
        ]
        prompt = persona.prompts[0] if persona.prompts else None

    should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False

    # figure out if we want to use citations or quotes
    use_citations = (
        not DANSWER_BOT_USE_QUOTES
        if channel_config is None
        else channel_config.response_type == SlackBotResponseType.CITATIONS
    )

    # List of user id to send message to, if None, send to everyone in channel
    send_to: list[str] | None = None
    respond_tag_only = False
    respond_team_member_list = None

    bypass_acl = False
    if (
        channel_config
        and channel_config.persona
        and channel_config.persona.document_sets
    ):
        # For Slack channels, use the full document set, admin will be warned when configuring it
        # with non-public document sets
        bypass_acl = True
    respond_member_group_list = None

    channel_conf = None
    if channel_config and channel_config.channel_config:
        channel_conf = channel_config.channel_config
    if slack_bot_config and slack_bot_config.channel_config:
        channel_conf = slack_bot_config.channel_config
        if not bypass_filters and "answer_filters" in channel_conf:
            reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]

            if (
                "questionmark_prefilter" in channel_conf["answer_filters"]
                and "?" not in messages[-1].message

@@ -262,8 +181,7 @@ def handle_message(
                )

        respond_tag_only = channel_conf.get("respond_tag_only") or False
        respond_team_member_list = channel_conf.get("respond_team_member_list") or None
        respond_slack_group_list = channel_conf.get("respond_slack_group_list") or None
        respond_member_group_list = channel_conf.get("respond_member_group_list", None)

    if respond_tag_only and not bypass_filters:
        logger.info(

@@ -272,17 +190,23 @@ def handle_message(
        )
        return False

    if respond_team_member_list:
        send_to, _ = fetch_userids_from_emails(respond_team_member_list, client)
    if respond_slack_group_list:
        user_ids, _ = fetch_userids_from_groups(respond_slack_group_list, client)
        send_to = (send_to + user_ids) if send_to else user_ids
    if send_to:
        send_to = list(set(send_to))  # remove duplicates
    # List of user id to send message to, if None, send to everyone in channel
    send_to: list[str] | None = None
    missing_users: list[str] | None = None
    if respond_member_group_list:
        send_to, missing_ids = fetch_user_ids_from_emails(
            respond_member_group_list, client
        )

        user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client)
        send_to = list(set(send_to + user_ids)) if send_to else user_ids

        if missing_users:
            logger.warning(f"Failed to find these users/groups: {missing_users}")

    # If configured to respond to team members only, then cannot be used with a /DanswerBot command
    # which would just respond to the sender
    if (respond_team_member_list or respond_slack_group_list) and is_bot_msg:
    if send_to and is_bot_msg:
        if sender_id:
            respond_in_thread(
                client=client,

@@ -297,324 +221,28 @@ def handle_message(
    except SlackApiError as e:
        logger.error(f"Was not able to react to user message due to: {e}")

    @retry(
        tries=num_retries,
        delay=0.25,
        backoff=2,
        logger=logger,
    )
    @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
    def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
        action = "slack_message"
        if is_bot_msg:
            action = "slack_slash_message"
        elif bypass_filters:
            action = "slack_tag_message"
        elif is_bot_dm:
            action = "slack_dm_message"

        slack_usage_report(action=action, sender_id=sender_id, client=client)

        max_document_tokens: int | None = None
        max_history_tokens: int | None = None

        with Session(get_sqlalchemy_engine()) as db_session:
            if len(new_message_request.messages) > 1:
                persona = cast(
                    Persona,
                    fetch_persona_by_id(db_session, new_message_request.persona_id),
                )
                llm, _ = get_llms_for_persona(persona)

                # In cases of threads, split the available tokens between docs and thread context
                input_tokens = get_max_input_tokens(
                    model_name=llm.config.model_name,
                    model_provider=llm.config.model_provider,
                )
                max_history_tokens = int(input_tokens * thread_context_percent)

                remaining_tokens = input_tokens - max_history_tokens

                query_text = new_message_request.messages[0].message
                if persona:
                    max_document_tokens = compute_max_document_tokens_for_persona(
                        persona=persona,
                        actual_user_input=query_text,
                        max_llm_token_override=remaining_tokens,
                    )
                else:
                    max_document_tokens = (
                        remaining_tokens
                        - 512  # Needs to be more than any of the QA prompts
                        - check_number_of_tokens(query_text)
                    )

            if DISABLE_GENERATIVE_AI:
                return None

            # This also handles creating the query event in postgres
            answer = get_search_answer(
                query_req=new_message_request,
                user=None,
                max_document_tokens=max_document_tokens,
                max_history_tokens=max_history_tokens,
                db_session=db_session,
                answer_generation_timeout=answer_generation_timeout,
                enable_reflexion=reflexion,
                bypass_acl=bypass_acl,
                use_citations=use_citations,
                danswerbot_flow=True,
            )
            if not answer.error_msg:
                return answer
            else:
                raise RuntimeError(answer.error_msg)

    try:
        # By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
        # it allows the slack flow to extract out filters from the user query
        filters = BaseFilters(
            source_type=None,
            document_set=document_set_names,
            time_cutoff=None,
    with Session(get_sqlalchemy_engine()) as db_session:
        # first check if we need to respond with a standard answer
        used_standard_answer = handle_standard_answers(
            message_info=message_info,
            receiver_ids=send_to,
            slack_bot_config=slack_bot_config,
            prompt=prompt,
            logger=logger,
            client=client,
            db_session=db_session,
        )

        # Default True because no other ways to apply filters in Slack (no nice UI)
        auto_detect_filters = (
            persona.llm_filter_extraction if persona is not None else True
        )
        if disable_auto_detect_filters:
            auto_detect_filters = False

        retrieval_details = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
            filters=filters,
            enable_auto_detect_filters=auto_detect_filters,
        )

        # This includes throwing out answer via reflexion
        answer = _get_answer(
            DirectQARequest(
                messages=messages,
                prompt_id=prompt.id if prompt else None,
                persona_id=persona.id if persona is not None else 0,
                retrieval_options=retrieval_details,
                chain_of_thought=not disable_cot,
                skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
            )
        )
    except Exception as e:
        logger.exception(
            f"Unable to process message - did not successfully answer "
            f"in {num_retries} attempts"
        )
        # Optionally, respond in thread with the error message, Used primarily
        # for debugging purposes
        if should_respond_with_error_msgs:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=None,
                text=f"Encountered exception when trying to answer: \n\n```{e}```",
                thread_ts=message_ts_to_respond_to,
            )

        # In case of failures, don't keep the reaction there permanently
        try:
            update_emote_react(
                emoji=DANSWER_REACT_EMOJI,
                channel=message_info.channel_to_respond,
                message_ts=message_info.msg_to_respond,
                remove=True,
                client=client,
            )
        except SlackApiError as e:
            logger.error(f"Failed to remove Reaction due to: {e}")

        return True

    # Edge case handling, for tracking down the Slack usage issue
    if answer is None:
        assert DISABLE_GENERATIVE_AI is True
        try:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=send_to,
                text="Hello! Danswer has some results for you!",
                blocks=[
                    SectionBlock(
                        text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
                    )
                ],
                thread_ts=message_ts_to_respond_to,
                # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
                unfurl=False,
            )

            # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
            # the ephemeral message. This also will give the user a notification which ephemeral message does not.
            if respond_team_member_list or respond_slack_group_list:
                respond_in_thread(
                    client=client,
                    channel=channel,
                    text=(
                        "👋 Hi, we've just gathered and forwarded the relevant "
                        + "information to the team. They'll get back to you shortly!"
                    ),
                    thread_ts=message_ts_to_respond_to,
                )

        if used_standard_answer:
            return False

        except Exception:
            logger.exception(
                f"Unable to process message - could not respond in slack in {num_retries} attempts"
            )
            return True

    # Got an answer at this point, can remove reaction and give results
    try:
        update_emote_react(
            emoji=DANSWER_REACT_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=True,
            client=client,
        )
    except SlackApiError as e:
        logger.error(f"Failed to remove Reaction due to: {e}")

    if answer.answer_valid is False:
        logger.info(
            "Answer was evaluated to be invalid, throwing it away without responding."
        )
        update_emote_react(
            emoji=DANSWER_FOLLOWUP_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=False,
            client=client,
        )

        if answer.answer:
            logger.debug(answer.answer)
        return True

    retrieval_info = answer.docs
    if not retrieval_info:
        # This should not happen, even with no docs retrieved, there is still info returned
        raise RuntimeError("Failed to retrieve docs, cannot answer question.")

    top_docs = retrieval_info.top_documents
    if not top_docs and not should_respond_even_with_no_docs:
        logger.error(
            f"Unable to answer question: '{answer.rephrase}' - no documents found"
        )
        # Optionally, respond in thread with the error message
        # Used primarily for debugging purposes
        if should_respond_with_error_msgs:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=None,
                text="Found no documents when trying to answer. Did you index any documents?",
                thread_ts=message_ts_to_respond_to,
            )
        return True

    if not answer.answer and disable_docs_only_answer:
        logger.info(
            "Unable to find answer - not responding since the "
            "`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
        )
        return True

    # If called with the DanswerBot slash command, the question is lost so we have to reshow it
    restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)

    answer_blocks = build_qa_response_blocks(
        message_id=answer.chat_message_id,
        answer=answer.answer,
        quotes=answer.quotes.quotes if answer.quotes else None,
        source_filters=retrieval_info.applied_source_filters,
        time_cutoff=retrieval_info.applied_time_cutoff,
        favor_recent=retrieval_info.recency_bias_multiplier > 1,
        # currently Personas don't support quotes
        # if citations are enabled, also don't use quotes
        skip_quotes=persona is not None or use_citations,
        process_message_for_citations=use_citations,
        feedback_reminder_id=feedback_reminder_id,
    )

    # Get the chunks fed to the LLM only, then fill with other docs
    llm_doc_inds = answer.llm_chunks_indices or []
    llm_docs = [top_docs[i] for i in llm_doc_inds]
    remaining_docs = [
        doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
    ]
    priority_ordered_docs = llm_docs + remaining_docs

    document_blocks = []
    citations_block = []
    # if citations are enabled, only show cited documents
    if use_citations:
        citations = answer.citations or []
        cited_docs = []
        for citation in citations:
            matching_doc = next(
                (d for d in top_docs if d.document_id == citation.document_id),
                None,
            )
            if matching_doc:
                cited_docs.append((citation.citation_num, matching_doc))

        cited_docs.sort()
        citations_block = build_sources_blocks(cited_documents=cited_docs)
    elif priority_ordered_docs:
        document_blocks = build_documents_blocks(
            documents=priority_ordered_docs,
            message_id=answer.chat_message_id,
        )
        document_blocks = [DividerBlock()] + document_blocks

    all_blocks = (
        restate_question_block + answer_blocks + citations_block + document_blocks
    )

    if channel_conf and channel_conf.get("follow_up_tags") is not None:
        all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))

    try:
        respond_in_thread(
        # if no standard answer applies, try a regular answer
        issue_with_regular_answer = handle_regular_answer(
            message_info=message_info,
            slack_bot_config=slack_bot_config,
            receiver_ids=send_to,
            client=client,
            channel=channel,
            receiver_ids=send_to,
            text="Hello! Danswer has some results for you!",
            blocks=all_blocks,
            thread_ts=message_ts_to_respond_to,
            # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
            unfurl=False,
            logger=logger,
            feedback_reminder_id=feedback_reminder_id,
        )

        # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
        # the ephemeral message. This also will give the user a notification which ephemeral message does not.
        if respond_team_member_list or respond_slack_group_list:
            respond_in_thread(
                client=client,
                channel=channel,
                text=(
                    "👋 Hi, we've just gathered and forwarded the relevant "
                    + "information to the team. They'll get back to you shortly!"
                ),
                thread_ts=message_ts_to_respond_to,
            )

        return False

    except Exception:
        logger.exception(
            f"Unable to process message - could not respond in slack in {num_retries} attempts"
        )
        return True
        return issue_with_regular_answer
@@ -0,0 +1,465 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar

from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW


srl = SlackRateLimiter()

RT = TypeVar("RT")  # return type


def rate_limits(
    client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
    def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> RT:
            if not srl.is_available():
                func_randid, position = srl.init_waiter()
                srl.notify(client, channel, position, thread_ts)
                while not srl.is_available():
                    srl.waiter(func_randid)
            srl.acquire_slot()
            return func(*args, **kwargs)

        return wrapper

    return decorator


def handle_regular_answer(
    message_info: SlackMessageInfo,
    slack_bot_config: SlackBotConfig | None,
    receiver_ids: list[str] | None,
    client: WebClient,
    channel: str,
    logger: logging.Logger,
    feedback_reminder_id: str | None,
    num_retries: int = DANSWER_BOT_NUM_RETRIES,
    answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
    thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
    should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
    disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
    disable_cot: bool = DANSWER_BOT_DISABLE_COT,
    reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
    channel_conf = slack_bot_config.channel_config if slack_bot_config else None

    messages = message_info.thread_messages
    message_ts_to_respond_to = message_info.msg_to_respond
    is_bot_msg = message_info.is_bot_msg

    document_set_names: list[str] | None = None
    persona = slack_bot_config.persona if slack_bot_config else None
    prompt = None
    if persona:
        document_set_names = [
            document_set.name for document_set in persona.document_sets
        ]
        prompt = persona.prompts[0] if persona.prompts else None

    should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False

    bypass_acl = False
    if (
        slack_bot_config
        and slack_bot_config.persona
        and slack_bot_config.persona.document_sets
    ):
        # For Slack channels, use the full document set, admin will be warned when configuring it
        # with non-public document sets
        bypass_acl = True

    # figure out if we want to use citations or quotes
    use_citations = (
        not DANSWER_BOT_USE_QUOTES
        if slack_bot_config is None
        else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
    )

    if not message_ts_to_respond_to:
        raise RuntimeError(
            "No message timestamp to respond to in `handle_message`. This should never happen."
        )

    @retry(
        tries=num_retries,
        delay=0.25,
        backoff=2,
        logger=logger,
    )
    @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
    def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
        max_document_tokens: int | None = None
        max_history_tokens: int | None = None

        with Session(get_sqlalchemy_engine()) as db_session:
            if len(new_message_request.messages) > 1:
                persona = cast(
                    Persona,
                    fetch_persona_by_id(db_session, new_message_request.persona_id),
                )
                llm, _ = get_llms_for_persona(persona)

                # In cases of threads, split the available tokens between docs and thread context
                input_tokens = get_max_input_tokens(
                    model_name=llm.config.model_name,
                    model_provider=llm.config.model_provider,
                )
                max_history_tokens = int(input_tokens * thread_context_percent)

                remaining_tokens = input_tokens - max_history_tokens

                query_text = new_message_request.messages[0].message
                if persona:
                    max_document_tokens = compute_max_document_tokens_for_persona(
                        persona=persona,
                        actual_user_input=query_text,
                        max_llm_token_override=remaining_tokens,
                    )
                else:
                    max_document_tokens = (
                        remaining_tokens
                        - 512  # Needs to be more than any of the QA prompts
                        - check_number_of_tokens(query_text)
                    )

            if DISABLE_GENERATIVE_AI:
                return None

            # This also handles creating the query event in postgres
            answer = get_search_answer(
                query_req=new_message_request,
                user=None,
                max_document_tokens=max_document_tokens,
                max_history_tokens=max_history_tokens,
                db_session=db_session,
                answer_generation_timeout=answer_generation_timeout,
                enable_reflexion=reflexion,
                bypass_acl=bypass_acl,
                use_citations=use_citations,
                danswerbot_flow=True,
            )
            if not answer.error_msg:
                return answer
            else:
                raise RuntimeError(answer.error_msg)

    try:
        # By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
        # it allows the slack flow to extract out filters from the user query
        filters = BaseFilters(
            source_type=None,
            document_set=document_set_names,
            time_cutoff=None,
        )

        # Default True because no other ways to apply filters in Slack (no nice UI)
        # Commenting this out because this is only available to the slackbot for now
        # later we plan to implement this at the persona level where this will get
        # commented back in
        # auto_detect_filters = (
        #     persona.llm_filter_extraction if persona is not None else True
        # )
        auto_detect_filters = (
            slack_bot_config.enable_auto_filters
            if slack_bot_config is not None
            else False
        )
        retrieval_details = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
            filters=filters,
            enable_auto_detect_filters=auto_detect_filters,
        )

        # This includes throwing out answer via reflexion
        answer = _get_answer(
            DirectQARequest(
                messages=messages,
                prompt_id=prompt.id if prompt else None,
                persona_id=persona.id if persona is not None else 0,
                retrieval_options=retrieval_details,
                chain_of_thought=not disable_cot,
                skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
            )
        )
    except Exception as e:
        logger.exception(
            f"Unable to process message - did not successfully answer "
            f"in {num_retries} attempts"
        )
        # Optionally, respond in thread with the error message, Used primarily
        # for debugging purposes
        if should_respond_with_error_msgs:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=None,
                text=f"Encountered exception when trying to answer: \n\n```{e}```",
                thread_ts=message_ts_to_respond_to,
            )

        # In case of failures, don't keep the reaction there permanently
        update_emote_react(
            emoji=DANSWER_REACT_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=True,
            client=client,
        )

        return True

    # Edge case handling, for tracking down the Slack usage issue
    if answer is None:
        assert DISABLE_GENERATIVE_AI is True
        try:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=receiver_ids,
                text="Hello! Danswer has some results for you!",
                blocks=[
                    SectionBlock(
                        text="Danswer is down for maintenance.\nWe're working hard on recharging the AI!"
                    )
                ],
                thread_ts=message_ts_to_respond_to,
                # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
                unfurl=False,
            )

            # For DM (ephemeral message), we need to create a thread via a normal message so the user can see
            # the ephemeral message. This also will give the user a notification which ephemeral message does not.
            if receiver_ids:
                respond_in_thread(
                    client=client,
                    channel=channel,
                    text=(
                        "👋 Hi, we've just gathered and forwarded the relevant "
                        + "information to the team. They'll get back to you shortly!"
                    ),
                    thread_ts=message_ts_to_respond_to,
                )

            return False

        except Exception:
            logger.exception(
                f"Unable to process message - could not respond in slack in {num_retries} attempts"
            )
            return True

    # Got an answer at this point, can remove reaction and give results
    update_emote_react(
        emoji=DANSWER_REACT_EMOJI,
        channel=message_info.channel_to_respond,
        message_ts=message_info.msg_to_respond,
        remove=True,
        client=client,
    )

    if answer.answer_valid is False:
        logger.info(
            "Answer was evaluated to be invalid, throwing it away without responding."
        )
        update_emote_react(
            emoji=DANSWER_FOLLOWUP_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=False,
            client=client,
        )

        if answer.answer:
            logger.debug(answer.answer)
        return True

    retrieval_info = answer.docs
    if not retrieval_info:
        # This should not happen, even with no docs retrieved, there is still info returned
        raise RuntimeError("Failed to retrieve docs, cannot answer question.")

    top_docs = retrieval_info.top_documents
    if not top_docs and not should_respond_even_with_no_docs:
        logger.error(
            f"Unable to answer question: '{answer.rephrase}' - no documents found"
        )
        # Optionally, respond in thread with the error message
        # Used primarily for debugging purposes
        if should_respond_with_error_msgs:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=None,
                text="Found no documents when trying to answer. Did you index any documents?",
                thread_ts=message_ts_to_respond_to,
            )
        return True

    if not answer.answer and disable_docs_only_answer:
        logger.info(
            "Unable to find answer - not responding since the "
            "`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
        )
        return True

    only_respond_with_citations_or_quotes = (
        channel_conf
        and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
    )
    has_citations_or_quotes = bool(answer.citations or answer.quotes)
    if (
        only_respond_with_citations_or_quotes
        and not has_citations_or_quotes
        and not message_info.bypass_filters
    ):
        logger.error(
            f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
        )
        # Optionally, respond in thread with the error message
        # Used primarily for debugging purposes
        if should_respond_with_error_msgs:
            respond_in_thread(
                client=client,
                channel=channel,
                receiver_ids=None,
                text="Found no citations or quotes when trying to answer.",
                thread_ts=message_ts_to_respond_to,
            )
        return True

    # If called with the DanswerBot slash command, the question is lost so we have to reshow it
    restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
|
||||
|
||||
answer_blocks = build_qa_response_blocks(
|
||||
message_id=answer.chat_message_id,
|
||||
answer=answer.answer,
|
||||
quotes=answer.quotes.quotes if answer.quotes else None,
|
||||
source_filters=retrieval_info.applied_source_filters,
|
||||
time_cutoff=retrieval_info.applied_time_cutoff,
|
||||
favor_recent=retrieval_info.recency_bias_multiplier > 1,
|
||||
# currently Personas don't support quotes
|
||||
# if citations are enabled, also don't use quotes
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
# Get the chunks fed to the LLM only, then fill with other docs
|
||||
llm_doc_inds = answer.llm_chunks_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
|
||||
document_blocks = []
|
||||
citations_block = []
|
||||
# if citations are enabled, only show cited documents
|
||||
if use_citations:
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = build_sources_blocks(cited_documents=cited_docs)
|
||||
elif priority_ordered_docs:
|
||||
document_blocks = build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block + answer_blocks + citations_block + document_blocks
|
||||
)
|
||||
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
text="Hello! Danswer has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unable to process message - could not respond in slack in {num_retries} attempts"
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,216 @@
import logging

from slack_sdk import WebClient
from sqlalchemy.orm import Session

from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import build_standard_answer_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_messages_by_sessions
from danswer.db.chat import get_chat_sessions_by_slack_thread_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.models import Prompt
from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import setup_logger

logger = setup_logger()


def oneoff_standard_answers(
    message: str,
    slack_bot_categories: list[str],
    db_session: Session,
) -> list[StandardAnswer]:
    """
    Respond to the user message if it matches any configured standard answers.

    Returns a list of matching StandardAnswers if found, otherwise None.
    """
    configured_standard_answers = {
        standard_answer
        for category in fetch_standard_answer_categories_by_names(
            slack_bot_categories, db_session=db_session
        )
        for standard_answer in category.standard_answers
    }

    matching_standard_answers = find_matching_standard_answers(
        query=message,
        id_in=[answer.id for answer in configured_standard_answers],
        db_session=db_session,
    )

    server_standard_answers = [
        StandardAnswer.from_model(db_answer) for db_answer in matching_standard_answers
    ]
    return server_standard_answers


def handle_standard_answers(
    message_info: SlackMessageInfo,
    receiver_ids: list[str] | None,
    slack_bot_config: SlackBotConfig | None,
    prompt: Prompt | None,
    logger: logging.Logger,
    client: WebClient,
    db_session: Session,
) -> bool:
    """
    Potentially respond to the user message depending on whether the user's message matches
    any of the configured standard answers and also whether those answers have already been
    provided in the current thread.

    Returns True if standard answers are found to match the user's message and therefore,
    we still need to respond to the users.
    """
    # if no channel config, then no standard answers are configured
    if not slack_bot_config:
        return False

    slack_thread_id = message_info.thread_to_respond
    configured_standard_answer_categories = (
        slack_bot_config.standard_answer_categories if slack_bot_config else []
    )
    configured_standard_answers = set(
        [
            standard_answer
            for standard_answer_category in configured_standard_answer_categories
            for standard_answer in standard_answer_category.standard_answers
        ]
    )
    query_msg = message_info.thread_messages[-1]

    if slack_thread_id is None:
        used_standard_answer_ids = set([])
    else:
        chat_sessions = get_chat_sessions_by_slack_thread_id(
            slack_thread_id=slack_thread_id,
            user_id=None,
            db_session=db_session,
        )
        chat_messages = get_chat_messages_by_sessions(
            chat_session_ids=[chat_session.id for chat_session in chat_sessions],
            user_id=None,
            db_session=db_session,
            skip_permission_check=True,
        )
        used_standard_answer_ids = set(
            [
                standard_answer.id
                for chat_message in chat_messages
                for standard_answer in chat_message.standard_answers
            ]
        )

    usable_standard_answers = configured_standard_answers.difference(
        used_standard_answer_ids
    )
    if usable_standard_answers:
        matching_standard_answers = find_matching_standard_answers(
            query=query_msg.message,
            id_in=[standard_answer.id for standard_answer in usable_standard_answers],
            db_session=db_session,
        )
    else:
        matching_standard_answers = []
    if matching_standard_answers:
        chat_session = create_chat_session(
            db_session=db_session,
            description="",
            user_id=None,
            persona_id=slack_bot_config.persona.id if slack_bot_config.persona else 0,
            danswerbot_flow=True,
            slack_thread_id=slack_thread_id,
            one_shot=True,
        )

        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )

        new_user_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=root_message,
            prompt_id=prompt.id if prompt else None,
            message=query_msg.message,
            token_count=0,
            message_type=MessageType.USER,
            db_session=db_session,
            commit=True,
        )

        formatted_answers = []
        for standard_answer in matching_standard_answers:
            block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
            formatted_answer = (
                f'Since you mentioned _"{standard_answer.keyword}"_, '
                f"I thought this might be useful: \n\n{block_quotified_answer}"
            )
            formatted_answers.append(formatted_answer)
        answer_message = "\n\n".join(formatted_answers)

        _ = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=new_user_message,
            prompt_id=prompt.id if prompt else None,
            message=answer_message,
            token_count=0,
            message_type=MessageType.ASSISTANT,
            error=None,
            db_session=db_session,
            commit=True,
        )

        update_emote_react(
            emoji=DANSWER_REACT_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=True,
            client=client,
        )

        restate_question_blocks = get_restate_blocks(
            msg=query_msg.message,
            is_bot_msg=message_info.is_bot_msg,
        )

        answer_blocks = build_standard_answer_blocks(
            answer_message=answer_message,
        )

        all_blocks = restate_question_blocks + answer_blocks

        try:
            respond_in_thread(
                client=client,
                channel=message_info.channel_to_respond,
                receiver_ids=receiver_ids,
                text="Hello! Danswer has some results for you!",
                blocks=all_blocks,
                thread_ts=message_info.msg_to_respond,
                unfurl=False,
            )

            if receiver_ids and slack_thread_id:
                send_team_member_message(
                    client=client,
                    channel=message_info.channel_to_respond,
                    thread_ts=slack_thread_id,
                )

            return True
        except Exception as e:
            logger.exception(f"Unable to send standard answer message: {e}")
            return False
    else:
        return False
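A minimal sketch of how this handler is presumably wired into the existing handle_message flow, shown here only for illustration. The variable names below (message_info, send_to, slack_bot_config, prompt, client, db_session) are assumed to exist in that surrounding context and are not part of this diff.

    # Hypothetical call site: run the keyword check before the retrieval/LLM path
    # so a configured standard answer can short-circuit answer generation.
    used_standard_answers = handle_standard_answers(
        message_info=message_info,
        receiver_ids=send_to,
        slack_bot_config=slack_bot_config,
        prompt=prompt,
        logger=logger,
        client=client,
        db_session=db_session,
    )
    if used_standard_answers:
        return False  # a standard answer was already posted in the thread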
backend/danswer/danswerbot/slack/handlers/utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from slack_sdk import WebClient

from danswer.danswerbot.slack.utils import respond_in_thread


def send_team_member_message(
    client: WebClient,
    channel: str,
    thread_ts: str,
) -> None:
    respond_in_thread(
        client=client,
        channel=channel,
        text=(
            "👋 Hi, we've just gathered and forwarded the relevant "
            + "information to the team. They'll get back to you shortly!"
        ),
        thread_ts=thread_ts,
    )
@@ -18,6 +18,7 @@ from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
|
||||
@@ -27,6 +28,9 @@ from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_but
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_followup_resolved_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import (
|
||||
handle_generate_answer_button,
|
||||
)
|
||||
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
|
||||
from danswer.danswerbot.slack.handlers.handle_message import handle_message
|
||||
from danswer.danswerbot.slack.handlers.handle_message import (
|
||||
@@ -266,6 +270,7 @@ def build_request_details(
|
||||
thread_messages=thread_messages,
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=cast(str, message_ts or thread_ts),
|
||||
thread_to_respond=cast(str, thread_ts or message_ts),
|
||||
sender=event.get("user") or None,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
@@ -283,6 +288,7 @@ def build_request_details(
|
||||
thread_messages=[single_msg],
|
||||
channel_to_respond=channel,
|
||||
msg_to_respond=None,
|
||||
thread_to_respond=None,
|
||||
sender=sender,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
@@ -352,7 +358,7 @@ def process_message(
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
channel_config=slack_bot_config,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
@@ -390,6 +396,8 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
return handle_followup_resolved_button(req, client, immediate=True)
|
||||
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
|
||||
return handle_followup_resolved_button(req, client, immediate=False)
|
||||
elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID:
|
||||
return handle_generate_answer_button(req, client)
|
||||
|
||||
|
||||
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
|
||||
@@ -461,13 +469,13 @@ if __name__ == "__main__":
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
|
||||
@@ -7,6 +7,7 @@ class SlackMessageInfo(BaseModel):
|
||||
thread_messages: list[ThreadMessage]
|
||||
channel_to_respond: str
|
||||
msg_to_respond: str | None
|
||||
thread_to_respond: str | None
|
||||
sender: str | None
|
||||
bypass_filters: bool # User has tagged @DanswerBot
|
||||
is_bot_msg: bool # User is using /DanswerBot
|
||||
|
||||
@@ -77,17 +77,25 @@ def update_emote_react(
|
||||
remove: bool,
|
||||
client: WebClient,
|
||||
) -> None:
|
||||
if not message_ts:
|
||||
logger.error(f"Tried to remove a react in {channel} but no message specified")
|
||||
return
|
||||
try:
|
||||
if not message_ts:
|
||||
logger.error(
|
||||
f"Tried to remove a react in {channel} but no message specified"
|
||||
)
|
||||
return
|
||||
|
||||
func = client.reactions_remove if remove else client.reactions_add
|
||||
slack_call = make_slack_api_rate_limited(func) # type: ignore
|
||||
slack_call(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
func = client.reactions_remove if remove else client.reactions_add
|
||||
slack_call = make_slack_api_rate_limited(func) # type: ignore
|
||||
slack_call(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if remove:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
else:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
|
||||
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
|
||||
@@ -136,16 +144,13 @@ def respond_in_thread(
|
||||
receiver_ids: list[str] | None = None,
|
||||
metadata: Metadata | None = None,
|
||||
unfurl: bool = True,
|
||||
) -> None:
|
||||
) -> list[str]:
|
||||
if not text and not blocks:
|
||||
raise ValueError("One of `text` or `blocks` must be provided")
|
||||
|
||||
message_ids: list[str] = []
|
||||
if not receiver_ids:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
|
||||
if not receiver_ids:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
text=text,
|
||||
@@ -157,7 +162,9 @@ def respond_in_thread(
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
for receiver in receiver_ids:
|
||||
response = slack_call(
|
||||
channel=channel,
|
||||
@@ -171,6 +178,9 @@ def respond_in_thread(
|
||||
)
|
||||
if not response.get("ok"):
|
||||
raise RuntimeError(f"Failed to post message: {response}")
|
||||
message_ids.append(response["message_ts"])
|
||||
|
||||
return message_ids
|
||||
|
||||
|
||||
def build_feedback_id(
|
||||
@@ -292,7 +302,7 @@ def get_channel_name_from_id(
|
||||
raise e
|
||||
|
||||
|
||||
def fetch_userids_from_emails(
|
||||
def fetch_user_ids_from_emails(
|
||||
user_emails: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
user_ids: list[str] = []
|
||||
@@ -308,57 +318,72 @@ def fetch_userids_from_emails(
|
||||
return user_ids, failed_to_find
|
||||
|
||||
|
||||
def fetch_userids_from_groups(
|
||||
group_names: list[str], client: WebClient
|
||||
def fetch_user_ids_from_groups(
|
||||
given_names: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
user_ids: list[str] = []
|
||||
failed_to_find: list[str] = []
|
||||
for group_name in group_names:
|
||||
try:
|
||||
# First, find the group ID from the group name
|
||||
response = client.usergroups_list()
|
||||
groups = {group["name"]: group["id"] for group in response["usergroups"]}
|
||||
group_id = groups.get(group_name)
|
||||
try:
|
||||
response = client.usergroups_list()
|
||||
if not isinstance(response.data, dict):
|
||||
logger.error("Error fetching user groups")
|
||||
return user_ids, given_names
|
||||
|
||||
if group_id:
|
||||
# Fetch user IDs for the group
|
||||
all_group_data = response.data.get("usergroups", [])
|
||||
name_id_map = {d["name"]: d["id"] for d in all_group_data}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
|
||||
for given_name in given_names:
|
||||
group_id = name_id_map.get(given_name) or handle_id_map.get(
|
||||
given_name.lstrip("@")
|
||||
)
|
||||
if not group_id:
|
||||
failed_to_find.append(given_name)
|
||||
continue
|
||||
try:
|
||||
response = client.usergroups_users_list(usergroup=group_id)
|
||||
user_ids.extend(response["users"])
|
||||
else:
|
||||
failed_to_find.append(group_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching user IDs for group {group_name}: {str(e)}")
|
||||
failed_to_find.append(group_name)
|
||||
if isinstance(response.data, dict):
|
||||
user_ids.extend(response.data.get("users", []))
|
||||
else:
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching user group ids: {str(e)}")
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching user groups: {str(e)}")
|
||||
failed_to_find = given_names
|
||||
|
||||
return user_ids, failed_to_find
|
||||
|
||||
|
||||
def fetch_groupids_from_names(
|
||||
names: list[str], client: WebClient
|
||||
def fetch_group_ids_from_names(
|
||||
given_names: list[str], client: WebClient
|
||||
) -> tuple[list[str], list[str]]:
|
||||
group_ids: set[str] = set()
|
||||
group_data: list[str] = []
|
||||
failed_to_find: list[str] = []
|
||||
|
||||
try:
|
||||
response = client.usergroups_list()
|
||||
if response.get("ok") and "usergroups" in response.data:
|
||||
all_groups_dicts = response.data["usergroups"] # type: ignore
|
||||
name_id_map = {d["name"]: d["id"] for d in all_groups_dicts}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_groups_dicts}
|
||||
for group in names:
|
||||
if group in name_id_map:
|
||||
group_ids.add(name_id_map[group])
|
||||
elif group in handle_id_map:
|
||||
group_ids.add(handle_id_map[group])
|
||||
else:
|
||||
failed_to_find.append(group)
|
||||
else:
|
||||
# Most likely a Slack App scope issue
|
||||
if not isinstance(response.data, dict):
|
||||
logger.error("Error fetching user groups")
|
||||
return group_data, given_names
|
||||
|
||||
all_group_data = response.data.get("usergroups", [])
|
||||
|
||||
name_id_map = {d["name"]: d["id"] for d in all_group_data}
|
||||
handle_id_map = {d["handle"]: d["id"] for d in all_group_data}
|
||||
|
||||
for given_name in given_names:
|
||||
id = handle_id_map.get(given_name.lstrip("@"))
|
||||
id = id or name_id_map.get(given_name)
|
||||
if id:
|
||||
group_data.append(id)
|
||||
else:
|
||||
failed_to_find.append(given_name)
|
||||
except Exception as e:
|
||||
failed_to_find = given_names
|
||||
logger.error(f"Error fetching user groups: {str(e)}")
|
||||
|
||||
return list(group_ids), failed_to_find
|
||||
return group_data, failed_to_find
|
||||
|
||||
|
||||
def fetch_user_semantic_id_from_id(
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.chat.models import LLMRelevanceSummaryResponse
|
||||
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
@@ -33,6 +39,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -67,17 +74,59 @@ def get_chat_session_by_id(
    return chat_session


def get_chat_sessions_by_slack_thread_id(
    slack_thread_id: str,
    user_id: UUID | None,
    db_session: Session,
) -> Sequence[ChatSession]:
    stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id)
    if user_id is not None:
        stmt = stmt.where(
            or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
        )
    return db_session.scalars(stmt).all()


def get_first_messages_for_chat_sessions(
    chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
    subquery = (
        select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
        .where(
            and_(
                ChatMessage.chat_session_id.in_(chat_session_ids),
                ChatMessage.message_type == MessageType.USER,  # Select USER messages
            )
        )
        .group_by(ChatMessage.chat_session_id)
        .subquery()
    )

    query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
        subquery,
        (ChatMessage.chat_session_id == subquery.c.chat_session_id)
        & (ChatMessage.id == subquery.c.min_id),
    )

    first_messages = db_session.execute(query).all()
    return dict([(row.chat_session_id, row.message) for row in first_messages])
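One possible use of these helpers, sketched for illustration only (not part of this diff); it assumes an open SQLAlchemy session named db_session and reuses get_chat_sessions_by_user with the only_one_shot flag introduced below.

    # Label one-shot chat sessions by the first user message in each session.
    sessions = get_chat_sessions_by_user(
        user_id=None, deleted=False, db_session=db_session, only_one_shot=True
    )
    first_messages = get_first_messages_for_chat_sessions(
        chat_session_ids=[session.id for session in sessions], db_session=db_session
    )
    labels = {
        session_id: message[:60]  # truncate long questions for display
        for session_id, message in first_messages.items()
    }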
|
||||
|
||||
def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
include_one_shot: bool = False,
|
||||
only_one_shot: bool = False,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if not include_one_shot:
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
@@ -97,6 +146,12 @@ def delete_search_doc_message_relationship(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
    stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
    db_session.execute(stmt)
    db_session.commit()
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
orphaned_docs = (
|
||||
db_session.query(SearchDoc)
|
||||
@@ -120,6 +175,7 @@ def delete_messages_and_files_from_chat_session(
|
||||
).fetchall()
|
||||
|
||||
for id, files in messages_with_files:
|
||||
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
|
||||
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
|
||||
for file_info in files or {}:
|
||||
lobj_name = file_info.get("id")
|
||||
@@ -139,11 +195,12 @@ def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None = None,
|
||||
persona_id: int,
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
@@ -153,6 +210,7 @@ def create_chat_session(
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
|
||||
db_session.add(chat_session)
|
||||
@@ -240,6 +298,39 @@ def get_chat_message(
    return chat_message


def get_chat_messages_by_sessions(
    chat_session_ids: list[int],
    user_id: UUID | None,
    db_session: Session,
    skip_permission_check: bool = False,
) -> Sequence[ChatMessage]:
    if not skip_permission_check:
        for chat_session_id in chat_session_ids:
            get_chat_session_by_id(
                chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
            )
    stmt = (
        select(ChatMessage)
        .where(ChatMessage.chat_session_id.in_(chat_session_ids))
        .order_by(nullsfirst(ChatMessage.parent_message))
    )
    return db_session.execute(stmt).scalars().all()
|
||||
|
||||
def get_search_docs_for_chat_message(
|
||||
chat_message_id: int, db_session: Session
|
||||
) -> list[SearchDoc]:
|
||||
stmt = (
|
||||
select(SearchDoc)
|
||||
.join(
|
||||
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
|
||||
)
|
||||
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_chat_messages_by_session(
|
||||
chat_session_id: int,
|
||||
user_id: UUID | None,
|
||||
@@ -260,8 +351,6 @@ def get_chat_messages_by_session(
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
|
||||
if prefetch_tool_calls:
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -449,6 +538,27 @@ def get_doc_query_identifiers_from_model(
    return doc_query_identifiers


def update_search_docs_table_with_relevance(
    db_session: Session,
    reference_db_search_docs: list[SearchDoc],
    relevance_summary: LLMRelevanceSummaryResponse,
) -> None:
    for search_doc in reference_db_search_docs:
        relevance_data = relevance_summary.relevance_summaries.get(
            f"{search_doc.document_id}-{search_doc.chunk_ind}"
        )
        if relevance_data is not None:
            db_session.execute(
                update(SearchDoc)
                .where(SearchDoc.id == search_doc.id)
                .values(
                    is_relevant=relevance_data.relevant,
                    relevance_explanation=relevance_data.content,
                )
            )
    db_session.commit()
|
||||
|
||||
def create_db_search_doc(
|
||||
server_search_doc: ServerSearchDoc,
|
||||
db_session: Session,
|
||||
@@ -463,17 +573,19 @@ def create_db_search_doc(
|
||||
boost=server_search_doc.boost,
|
||||
hidden=server_search_doc.hidden,
|
||||
doc_metadata=server_search_doc.metadata,
|
||||
is_relevant=server_search_doc.is_relevant,
|
||||
relevance_explanation=server_search_doc.relevance_explanation,
|
||||
# For docs further down that aren't reranked, we can't use the retrieval score
|
||||
score=server_search_doc.score or 0.0,
|
||||
match_highlights=server_search_doc.match_highlights,
|
||||
updated_at=server_search_doc.updated_at,
|
||||
primary_owners=server_search_doc.primary_owners,
|
||||
secondary_owners=server_search_doc.secondary_owners,
|
||||
is_internet=server_search_doc.is_internet,
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
db_session.commit()
|
||||
|
||||
return db_search_doc
|
||||
|
||||
|
||||
@@ -502,11 +614,14 @@ def translate_db_search_doc_to_server_search_doc(
|
||||
match_highlights=(
|
||||
db_search_doc.match_highlights if not remove_doc_content else []
|
||||
),
|
||||
relevance_explanation=db_search_doc.relevance_explanation,
|
||||
is_relevant=db_search_doc.is_relevant,
|
||||
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
|
||||
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
|
||||
secondary_owners=(
|
||||
db_search_doc.secondary_owners if not remove_doc_content else []
|
||||
),
|
||||
is_internet=db_search_doc.is_internet,
|
||||
)
|
||||
|
||||
|
||||
@@ -524,9 +639,11 @@ def get_retrieval_docs_from_chat_message(
|
||||
|
||||
|
||||
def translate_db_message_to_chat_message_detail(
|
||||
chat_message: ChatMessage, remove_doc_content: bool = False
|
||||
chat_message: ChatMessage,
|
||||
remove_doc_content: bool = False,
|
||||
) -> ChatMessageDetail:
|
||||
chat_msg_detail = ChatMessageDetail(
|
||||
chat_session_id=chat_message.chat_session_id,
|
||||
message_id=chat_message.id,
|
||||
parent_message=chat_message.parent_message,
|
||||
latest_child_message=chat_message.latest_child_message,
|
||||
|
||||
@@ -152,7 +152,7 @@ def add_credential_to_connector(
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
is_public: bool,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
@@ -201,7 +201,7 @@ def add_credential_to_connector(
|
||||
def remove_credential_from_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -12,6 +12,7 @@ from danswer.connectors.gmail.constants import (
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
@@ -142,6 +143,18 @@ def delete_credential(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
|
||||
"Please delete all associated connectors first."
|
||||
)
|
||||
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -31,6 +36,7 @@ def create_embedding_model(
|
||||
query_prefix=model_details.query_prefix,
|
||||
passage_prefix=model_details.passage_prefix,
|
||||
status=status,
|
||||
cloud_provider_id=model_details.cloud_provider_id,
|
||||
# Every single embedding model except the initial one from migrations has this name
|
||||
# The initial one from migration is called "danswer_chunk"
|
||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
||||
@@ -42,6 +48,42 @@ def create_embedding_model(
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_model_id_from_name(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> int | None:
|
||||
query = select(CloudEmbeddingProvider).where(
|
||||
CloudEmbeddingProvider.name == embedding_provider_name
|
||||
)
|
||||
provider = db_session.execute(query).scalars().first()
|
||||
return provider.id if provider else None
|
||||
|
||||
|
||||
def get_current_db_embedding_provider(
|
||||
db_session: Session,
|
||||
) -> ServerCloudEmbeddingProvider | None:
|
||||
current_embedding_model = EmbeddingModelDetail.from_model(
|
||||
get_current_db_embedding_model(db_session=db_session)
|
||||
)
|
||||
|
||||
if (
|
||||
current_embedding_model is None
|
||||
or current_embedding_model.cloud_provider_id is None
|
||||
):
|
||||
return None
|
||||
|
||||
embedding_provider = fetch_embedding_provider(
|
||||
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
|
||||
)
|
||||
if embedding_provider is None:
|
||||
raise RuntimeError("No embedding provider exists for this model.")
|
||||
|
||||
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
|
||||
cloud_provider_model=embedding_provider
|
||||
)
|
||||
|
||||
return current_embedding_provider
|
||||
|
||||
|
||||
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
||||
query = (
|
||||
select(EmbeddingModel)
|
||||
|
||||
@@ -2,11 +2,34 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
existing_provider = (
|
||||
db_session.query(CloudEmbeddingProviderModel)
|
||||
.filter_by(name=provider.name)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
for key, value in provider.dict().items():
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.dict())
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_provider)
|
||||
return CloudEmbeddingProvider.from_request(existing_provider)
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||
) -> FullLLMProvider:
|
||||
@@ -26,7 +49,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
@@ -46,10 +68,26 @@ def upsert_llm_provider(
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
db_session: Session,
|
||||
) -> list[CloudEmbeddingProviderModel]:
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_id: int
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
return db_session.scalar(
|
||||
select(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.id == provider_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.name == embedding_provider_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
|
||||
@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
@@ -246,6 +247,39 @@ class Persona__Tool(Base):
    tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True)


class StandardAnswer__StandardAnswerCategory(Base):
    __tablename__ = "standard_answer__standard_answer_category"

    standard_answer_id: Mapped[int] = mapped_column(
        ForeignKey("standard_answer.id"), primary_key=True
    )
    standard_answer_category_id: Mapped[int] = mapped_column(
        ForeignKey("standard_answer_category.id"), primary_key=True
    )


class SlackBotConfig__StandardAnswerCategory(Base):
    __tablename__ = "slack_bot_config__standard_answer_category"

    slack_bot_config_id: Mapped[int] = mapped_column(
        ForeignKey("slack_bot_config.id"), primary_key=True
    )
    standard_answer_category_id: Mapped[int] = mapped_column(
        ForeignKey("standard_answer_category.id"), primary_key=True
    )


class ChatMessage__StandardAnswer(Base):
    __tablename__ = "chat_message__standard_answer"

    chat_message_id: Mapped[int] = mapped_column(
        ForeignKey("chat_message.id"), primary_key=True
    )
    standard_answer_id: Mapped[int] = mapped_column(
        ForeignKey("standard_answer.id"), primary_key=True
    )


"""
Documents/Indexing Tables
"""
@@ -436,7 +470,7 @@ class Credential(Base):
|
||||
|
||||
class EmbeddingModel(Base):
|
||||
__tablename__ = "embedding_model"
|
||||
# ID is used also to indicate the order that the models are configured by the admin
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
model_name: Mapped[str] = mapped_column(String)
|
||||
model_dim: Mapped[int] = mapped_column(Integer)
|
||||
@@ -448,6 +482,16 @@ class EmbeddingModel(Base):
|
||||
)
|
||||
index_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# New field for cloud provider relationship
|
||||
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.id")
|
||||
)
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
"CloudEmbeddingProvider",
|
||||
back_populates="embedding_models",
|
||||
foreign_keys=[cloud_provider_id],
|
||||
)
|
||||
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="embedding_model"
|
||||
)
|
||||
@@ -467,6 +511,18 @@ class EmbeddingModel(Base):
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider else None
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str | None:
|
||||
return self.cloud_provider.name if self.cloud_provider else None
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
@@ -486,6 +542,7 @@ class IndexAttempt(Base):
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
# This is only for attempts that are explicitly marked as from the start via
|
||||
# the run once API
|
||||
@@ -612,6 +669,10 @@ class SearchDoc(Base):
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
|
||||
|
||||
is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
chat_messages = relationship(
|
||||
"ChatMessage",
|
||||
@@ -663,6 +724,10 @@ class ChatSession(Base):
|
||||
|
||||
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
|
||||
|
||||
slack_thread_id: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True, default=None
|
||||
)
|
||||
|
||||
# the latest "overrides" specified by the user. These take precedence over
|
||||
# the attached persona. However, overrides specified directly in the
|
||||
# `send-message` call will take precedence over these.
|
||||
@@ -760,6 +825,11 @@ class ChatMessage(Base):
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
@@ -836,11 +906,6 @@ class ChatMessageFeedback(Base):
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Structures, Organizational, Configurations Tables
|
||||
"""
|
||||
|
||||
|
||||
class LLMProvider(Base):
|
||||
__tablename__ = "llm_provider"
|
||||
|
||||
@@ -869,6 +934,29 @@ class LLMProvider(Base):
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
default_model_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("embedding_model.id"), nullable=True
|
||||
)
|
||||
|
||||
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
|
||||
"EmbeddingModel",
|
||||
back_populates="cloud_provider",
|
||||
foreign_keys="EmbeddingModel.cloud_provider_id",
|
||||
)
|
||||
default_model: Mapped["EmbeddingModel"] = relationship(
|
||||
"EmbeddingModel", foreign_keys=[default_model_id]
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingProvider(name='{self.name}')>"
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
__tablename__ = "document_set"
|
||||
|
||||
@@ -948,6 +1036,7 @@ class Tool(Base):
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
# tools defined via the UI will have this as None
|
||||
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_name: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
@@ -1077,14 +1166,60 @@ class ChannelConfig(TypedDict):
|
||||
channel_names: list[str]
|
||||
respond_tag_only: NotRequired[bool] # defaults to False
|
||||
respond_to_bots: NotRequired[bool] # defaults to False
|
||||
respond_team_member_list: NotRequired[list[str]]
|
||||
respond_slack_group_list: NotRequired[list[str]]
|
||||
respond_member_group_list: NotRequired[list[str]]
|
||||
answer_filters: NotRequired[list[AllowedAnswerFilters]]
|
||||
# If None then no follow up
|
||||
# If empty list, follow up with no tags
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
    __tablename__ = "standard_answer_category"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    standard_answers: Mapped[list["StandardAnswer"]] = relationship(
        "StandardAnswer",
        secondary=StandardAnswer__StandardAnswerCategory.__table__,
        back_populates="categories",
    )
    slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
        "SlackBotConfig",
        secondary=SlackBotConfig__StandardAnswerCategory.__table__,
        back_populates="standard_answer_categories",
    )


class StandardAnswer(Base):
    __tablename__ = "standard_answer"

    id: Mapped[int] = mapped_column(primary_key=True)
    keyword: Mapped[str] = mapped_column(String)
    answer: Mapped[str] = mapped_column(String)
    active: Mapped[bool] = mapped_column(Boolean)

    __table_args__ = (
        Index(
            "unique_keyword_active",
            keyword,
            active,
            unique=True,
            postgresql_where=(active == True),  # noqa: E712
        ),
    )

    categories: Mapped[list[StandardAnswerCategory]] = relationship(
        "StandardAnswerCategory",
        secondary=StandardAnswer__StandardAnswerCategory.__table__,
        back_populates="standard_answers",
    )
    chat_messages: Mapped[list[ChatMessage]] = relationship(
        "ChatMessage",
        secondary=ChatMessage__StandardAnswer.__table__,
        back_populates="standard_answers",
    )

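The unique_keyword_active index above is a PostgreSQL partial unique index: uniqueness is enforced only where active is true, so many retired answers can share a keyword while at most one remains active. A rough sketch of the resulting behavior, assuming an open SQLAlchemy session named db_session (illustrative only, not part of this diff):

    from sqlalchemy.exc import IntegrityError

    db_session.add(StandardAnswer(keyword="vpn", answer="See the VPN setup guide.", active=True))
    db_session.add(StandardAnswer(keyword="vpn", answer="Old guidance, kept for history.", active=False))
    db_session.commit()  # fine: only one row matches the active=true predicate

    try:
        db_session.add(StandardAnswer(keyword="vpn", answer="Another active answer.", active=True))
        db_session.commit()  # violates the unique_keyword_active partial index
    except IntegrityError:
        db_session.rollback()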
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1105,7 +1240,16 @@ class SlackBotConfig(Base):
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
)
|
||||
|
||||
|
||||
class TaskQueueState(Base):
|
||||
|
||||
@@ -12,8 +12,8 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import Persona
|
||||
@@ -62,19 +62,6 @@ def create_update_persona(
|
||||
) -> PersonaSnapshot:
|
||||
"""Higher level function than upsert_persona, although either is valid to use."""
|
||||
# Permission to actually use these is checked later
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
prompts = list(
|
||||
get_prompts_by_ids(
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
@@ -85,9 +72,9 @@ def create_update_persona(
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
prompts=prompts,
|
||||
prompt_ids=create_persona_request.prompt_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
document_sets=document_sets,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -330,13 +317,13 @@ def upsert_persona(
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
prompts: list[Prompt] | None,
|
||||
document_sets: list[DocumentSet] | None,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
prompt_ids: list[int] | None = None,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
@@ -356,6 +343,28 @@ def upsert_persona(
|
||||
if not tools and tool_ids:
|
||||
raise ValueError("Tools not found")
|
||||
|
||||
# Fetch and attach document_sets by IDs
|
||||
document_sets = None
|
||||
if document_set_ids is not None:
|
||||
document_sets = (
|
||||
db_session.query(DocumentSet)
|
||||
.filter(DocumentSet.id.in_(document_set_ids))
|
||||
.all()
|
||||
)
|
||||
if not document_sets and document_set_ids:
|
||||
raise ValueError("document_sets not found")
|
||||
|
||||
# Fetch and attach prompts by IDs
|
||||
prompts = None
|
||||
if prompt_ids is not None:
|
||||
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
|
||||
if not prompts and prompt_ids:
|
||||
raise ValueError("prompts not found")
|
||||
|
||||
# ensure all specified tools are valid
|
||||
if tools:
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
@@ -383,10 +392,10 @@ def upsert_persona(
|
||||
|
||||
if prompts is not None:
|
||||
persona.prompts.clear()
|
||||
persona.prompts = prompts
|
||||
persona.prompts = prompts or []
|
||||
|
||||
if tools is not None:
|
||||
persona.tools = tools
|
||||
persona.tools = tools or []
|
||||
|
||||
else:
|
||||
persona = Persona(
|
||||
@@ -453,6 +462,14 @@ def update_persona_visibility(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
for tool in tools:
|
||||
if tool.name == "InternetSearchTool" and not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Bing API key not found, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
|
||||
|
||||
def check_user_can_edit_persona(user: User | None, persona: Persona) -> None:
|
||||
# if user is None, assume that no-auth is turned on
|
||||
if user is None:
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
@@ -15,6 +14,7 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
@@ -42,12 +42,6 @@ def create_slack_bot_persona(
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
# create/update persona associated with the slack bot
|
||||
persona_name = _build_persona_name(channel_names)
|
||||
@@ -59,10 +53,10 @@ def create_slack_bot_persona(
|
||||
description="",
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompts=[default_prompt],
|
||||
document_sets=document_sets,
|
||||
prompt_ids=[default_prompt.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -79,12 +73,25 @@ def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
db_session.add(slack_bot_config)
|
||||
db_session.commit()
|
||||
@@ -97,6 +104,8 @@ def update_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
slack_bot_config = db_session.scalar(
|
||||
@@ -106,6 +115,16 @@ def update_slack_bot_config(
|
||||
raise ValueError(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
# get the existing persona id before updating the object
|
||||
existing_persona_id = slack_bot_config.persona_id
|
||||
|
||||
@@ -115,6 +134,10 @@ def update_slack_bot_config(
|
||||
slack_bot_config.persona_id = persona_id
|
||||
slack_bot_config.channel_config = channel_config
|
||||
slack_bot_config.response_type = response_type
|
||||
slack_bot_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
slack_bot_config.enable_auto_filters = enable_auto_filters
|
||||
|
||||
# if the persona has changed, then clean up the old persona
|
||||
if persona_id != existing_persona_id and existing_persona_id:
|
||||
|
||||
239
backend/danswer/db/standard_answer.py
Normal file
239
backend/danswer/db/standard_answer.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import StandardAnswer
|
||||
from danswer.db.models import StandardAnswerCategory
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_names(
|
||||
standard_answer_category_names: list[str],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.name.in_(standard_answer_category_names)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def find_matching_standard_answers(
|
||||
id_in: list[int],
|
||||
query: str,
|
||||
db_session: Session,
|
||||
) -> list[StandardAnswer]:
|
||||
stmt = (
|
||||
select(StandardAnswer)
|
||||
.where(StandardAnswer.active.is_(True))
|
||||
.where(StandardAnswer.id.in_(id_in))
|
||||
)
|
||||
possible_standard_answers = db_session.scalars(stmt).all()
|
||||
|
||||
matching_standard_answers: list[StandardAnswer] = []
|
||||
for standard_answer in possible_standard_answers:
|
||||
# Remove punctuation and split the keyword into individual words
|
||||
keyword_words = "".join(
|
||||
char
|
||||
for char in standard_answer.keyword.lower()
|
||||
if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Remove punctuation and split the query into individual words
|
||||
query_words = "".join(
|
||||
char for char in query.lower() if char not in string.punctuation
|
||||
).split()
|
||||
|
||||
# Check if all of the keyword words are in the query words
|
||||
if all(word in query_words for word in keyword_words):
|
||||
matching_standard_answers.append(standard_answer)
|
||||
|
||||
return matching_standard_answers
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
@@ -1,5 +1,6 @@
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -107,18 +108,28 @@ def create_or_add_document_tag_list(
|
||||
|
||||
|
||||
def get_tags_by_value_prefix_for_source_types(
|
||||
tag_key_prefix: str | None,
|
||||
tag_value_prefix: str | None,
|
||||
sources: list[DocumentSource] | None,
|
||||
limit: int | None,
|
||||
db_session: Session,
|
||||
) -> list[Tag]:
|
||||
query = select(Tag)
|
||||
|
||||
if tag_value_prefix:
|
||||
query = query.where(Tag.tag_value.startswith(tag_value_prefix))
|
||||
if tag_key_prefix or tag_value_prefix:
|
||||
conditions = []
|
||||
if tag_key_prefix:
|
||||
conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
|
||||
if tag_value_prefix:
|
||||
conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))
|
||||
query = query.where(or_(*conditions))
|
||||
|
||||
if sources:
|
||||
query = query.where(Tag.source.in_(sources))
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = db_session.execute(query)
|
||||
|
||||
tags = result.scalars().all()
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -186,7 +186,7 @@ class IdRetrievalCapable(abc.ABC):
|
||||
min_chunk_ind: int | None,
|
||||
max_chunk_ind: int | None,
|
||||
user_access_control_list: list[str] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Fetch chunk(s) based on document id
|
||||
|
||||
@@ -222,7 +222,7 @@ class KeywordCapable(abc.ABC):
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
|
||||
information required for query time purposes. For example, some details of the document
|
||||
@@ -262,7 +262,7 @@ class VectorCapable(abc.ABC):
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run vector/semantic search and return a list of inference chunks.
|
||||
|
||||
@@ -298,7 +298,7 @@ class HybridCapable(abc.ABC):
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
|
||||
@@ -348,7 +348,7 @@ class AdminCapable(abc.ABC):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run the special search for the admin document explorer page
|
||||
|
||||
|
||||
@@ -91,6 +91,9 @@ schema DANSWER_CHUNK_NAME {
|
||||
field metadata type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field metadata_suffix type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field doc_updated_at type int {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
@@ -150,43 +153,41 @@ schema DANSWER_CHUNK_NAME {
|
||||
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
|
||||
}
|
||||
|
||||
# This must be separate function for normalize_linear to work
|
||||
function vector_score() {
|
||||
function title_vector_score() {
|
||||
expression {
|
||||
# If no title, the full vector score comes from the content embedding
|
||||
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
|
||||
((1 - query(title_content_ratio)) * closeness(field, embeddings))
|
||||
}
|
||||
}
|
||||
|
||||
# This must be separate function for normalize_linear to work
|
||||
function keyword_score() {
|
||||
expression {
|
||||
(query(title_content_ratio) * bm25(title)) +
|
||||
((1 - query(title_content_ratio)) * bm25(content))
|
||||
#query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
|
||||
if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))
|
||||
}
|
||||
}
|
||||
|
||||
first-phase {
|
||||
expression: vector_score
|
||||
expression: closeness(field, embeddings)
|
||||
}
|
||||
|
||||
# Weighted average between Vector Search and BM-25
|
||||
# Each is a weighted average between the Title and Content fields
|
||||
# Finally each doc is boosted by it's user feedback based boost and recency
|
||||
# If any embedding or index field is missing, it just receives a score of 0
|
||||
# Assumptions:
|
||||
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
|
||||
# therefore not normalizing before combining.
|
||||
# - For documents without title, it gets a score of 0 for that and this is ok as documents
|
||||
# without any title match should be penalized.
|
||||
global-phase {
|
||||
expression {
|
||||
(
|
||||
# Weighted Vector Similarity Score
|
||||
(query(alpha) * normalize_linear(vector_score)) +
|
||||
(
|
||||
query(alpha) * (
|
||||
(query(title_content_ratio) * normalize_linear(title_vector_score))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
|
||||
)
|
||||
)
|
||||
|
||||
+
|
||||
|
||||
# Weighted Keyword Similarity Score
|
||||
((1 - query(alpha)) * normalize_linear(keyword_score))
|
||||
(
|
||||
(1 - query(alpha)) * (
|
||||
(query(title_content_ratio) * normalize_linear(bm25(title)))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
|
||||
)
|
||||
)
|
||||
)
|
||||
# Boost based on user feedback
|
||||
* document_boost
|
||||
@@ -201,8 +202,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
bm25(content)
|
||||
closeness(field, title_embedding)
|
||||
closeness(field, embeddings)
|
||||
keyword_score
|
||||
vector_score
|
||||
document_boost
|
||||
recency_bias
|
||||
closest(embeddings)
|
||||
|
||||
@@ -41,6 +41,7 @@ from danswer.configs.constants import HIDDEN
|
||||
from danswer.configs.constants import INDEX_SEPARATOR
|
||||
from danswer.configs.constants import METADATA
|
||||
from danswer.configs.constants import METADATA_LIST
|
||||
from danswer.configs.constants import METADATA_SUFFIX
|
||||
from danswer.configs.constants import PRIMARY_OWNERS
|
||||
from danswer.configs.constants import RECENCY_BIAS
|
||||
from danswer.configs.constants import SECONDARY_OWNERS
|
||||
@@ -51,7 +52,6 @@ from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.constants import TITLE
|
||||
from danswer.configs.constants import TITLE_EMBEDDING
|
||||
from danswer.configs.constants import TITLE_SEPARATOR
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
@@ -64,7 +64,7 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
|
||||
from danswer.document_index.vespa.utils import replace_invalid_doc_id_characters
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.search.retrieval.search_runner import query_processing
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.utils.batching import batch_generator
|
||||
@@ -119,6 +119,7 @@ def _does_document_exist(
|
||||
chunk. This checks for whether the chunk exists already in the index"""
|
||||
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
doc_fetch_response = http_client.get(doc_url)
|
||||
|
||||
if doc_fetch_response.status_code == 404:
|
||||
return False
|
||||
|
||||
@@ -346,8 +347,10 @@ def _index_vespa_chunk(
|
||||
TITLE: remove_invalid_unicode_chars(title) if title else None,
|
||||
SKIP_TITLE_EMBEDDING: not title,
|
||||
CONTENT: remove_invalid_unicode_chars(chunk.content),
|
||||
# This duplication of `content` is needed for keyword highlighting :(
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
# which contains the title prefix and metadata suffix
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
|
||||
SOURCE_TYPE: str(document.source.value),
|
||||
SOURCE_LINKS: json.dumps(chunk.source_links),
|
||||
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
|
||||
@@ -355,6 +358,7 @@ def _index_vespa_chunk(
|
||||
METADATA: json.dumps(document.metadata),
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
|
||||
METADATA_SUFFIX: chunk.metadata_suffix,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
BOOST: chunk.boost,
|
||||
@@ -559,7 +563,9 @@ def _process_dynamic_summary(
|
||||
return processed_summary
|
||||
|
||||
|
||||
def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
def _vespa_hit_to_inference_chunk(
|
||||
hit: dict[str, Any], null_score: bool = False
|
||||
) -> InferenceChunkUncleaned:
|
||||
fields = cast(dict[str, Any], hit["fields"])
|
||||
|
||||
# parse fields that are stored as strings, but are really json / datetime
|
||||
@@ -582,19 +588,6 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
|
||||
)
|
||||
|
||||
# Remove the title from the first chunk as every chunk already included
|
||||
# its semantic identifier for LLM
|
||||
content = fields[CONTENT]
|
||||
if fields[CHUNK_ID] == 0:
|
||||
parts = content.split(TITLE_SEPARATOR, maxsplit=1)
|
||||
content = parts[1] if len(parts) > 1 and "\n" not in parts[0] else content
|
||||
|
||||
# User ran into this, not sure why this could happen, error checking here
|
||||
blurb = fields.get(BLURB)
|
||||
if not blurb:
|
||||
logger.error(f"Chunk with id {fields.get(semantic_identifier)} ")
|
||||
blurb = ""
|
||||
|
||||
source_links = fields.get(SOURCE_LINKS, {})
|
||||
source_links_dict_unprocessed = (
|
||||
json.loads(source_links) if isinstance(source_links, str) else source_links
|
||||
@@ -604,29 +597,33 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
|
||||
for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
|
||||
}
|
||||
|
||||
return InferenceChunk(
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=fields[CHUNK_ID],
|
||||
blurb=blurb,
|
||||
content=content,
|
||||
blurb=fields.get(BLURB, ""), # Unused
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
|
||||
source_links=source_links_dict,
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
source_type=fields[SOURCE_TYPE],
|
||||
title=fields.get(TITLE),
|
||||
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
|
||||
boost=fields.get(BOOST, 1),
|
||||
recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
|
||||
score=hit.get("relevance", 0),
|
||||
score=None if null_score else hit.get("relevance", 0),
|
||||
hidden=fields.get(HIDDEN, False),
|
||||
primary_owners=fields.get(PRIMARY_OWNERS),
|
||||
secondary_owners=fields.get(SECONDARY_OWNERS),
|
||||
metadata=metadata,
|
||||
metadata_suffix=fields.get(METADATA_SUFFIX),
|
||||
match_highlights=match_highlights,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[InferenceChunk]:
|
||||
def _query_vespa(
|
||||
query_params: Mapping[str, str | int | float]
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
if "query" in query_params and not cast(str, query_params["query"]).strip():
|
||||
raise ValueError("No/empty query received")
|
||||
|
||||
@@ -681,16 +678,6 @@ def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[Inferenc
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _inference_chunk_by_vespa_id(vespa_id: str, index_name: str) -> InferenceChunk:
|
||||
res = requests.get(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_id}"
|
||||
)
|
||||
res.raise_for_status()
|
||||
|
||||
return _vespa_hit_to_inference_chunk(res.json())
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -735,6 +722,7 @@ class VespaIndex(DocumentIndex):
|
||||
f"{SOURCE_TYPE}, "
|
||||
f"{SOURCE_LINKS}, "
|
||||
f"{SEMANTIC_IDENTIFIER}, "
|
||||
f"{TITLE}, "
|
||||
f"{SECTION_CONTINUATION}, "
|
||||
f"{BOOST}, "
|
||||
f"{HIDDEN}, "
|
||||
@@ -742,6 +730,7 @@ class VespaIndex(DocumentIndex):
|
||||
f"{PRIMARY_OWNERS}, "
|
||||
f"{SECONDARY_OWNERS}, "
|
||||
f"{METADATA}, "
|
||||
f"{METADATA_SUFFIX}, "
|
||||
f"{CONTENT_SUMMARY} "
|
||||
f"from {{index_name}} where "
|
||||
)
|
||||
@@ -977,7 +966,7 @@ class VespaIndex(DocumentIndex):
|
||||
min_chunk_ind: int | None,
|
||||
max_chunk_ind: int | None,
|
||||
user_access_control_list: list[str] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
document_id = replace_invalid_doc_id_characters(document_id)
|
||||
|
||||
vespa_chunks = _get_vespa_chunks_by_document_id(
|
||||
@@ -992,7 +981,8 @@ class VespaIndex(DocumentIndex):
|
||||
return []
|
||||
|
||||
inference_chunks = [
|
||||
_vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
|
||||
_vespa_hit_to_inference_chunk(chunk, null_score=True)
|
||||
for chunk in vespa_chunks
|
||||
]
|
||||
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
|
||||
return inference_chunks
|
||||
@@ -1005,7 +995,7 @@ class VespaIndex(DocumentIndex):
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
@@ -1042,7 +1032,7 @@ class VespaIndex(DocumentIndex):
|
||||
offset: int = 0,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
@@ -1086,7 +1076,7 @@ class VespaIndex(DocumentIndex):
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
target_hits = max(10 * num_to_retrieve, 1000)
|
||||
@@ -1130,7 +1120,7 @@ class VespaIndex(DocumentIndex):
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters, include_hidden=True)
|
||||
yql = (
|
||||
VespaIndex.yql_base.format(index_name=self.index_name)
|
||||
|
||||
@@ -3,12 +3,16 @@ from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.app_configs import CHUNK_OVERLAP
|
||||
from danswer.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.constants import SECTION_SEPARATOR
|
||||
from danswer.configs.constants import TITLE_SEPARATOR
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
@@ -19,6 +23,14 @@ if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
CHUNK_OVERLAP = 0
|
||||
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
|
||||
# overwhelm the actual contents of the chunk
|
||||
MAX_METADATA_PERCENTAGE = 0.25
|
||||
CHUNK_MIN_CONTENT = 256
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
|
||||
@@ -44,6 +56,8 @@ def chunk_large_section(
|
||||
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix: str = "",
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
@@ -60,30 +74,69 @@ def chunk_large_section(
|
||||
source_document=document,
|
||||
chunk_id=start_chunk_id + chunk_ind,
|
||||
blurb=blurb,
|
||||
content=chunk_str,
|
||||
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
|
||||
content_summary=chunk_str,
|
||||
source_links={0: section_link_text},
|
||||
section_continuation=(chunk_ind != 0),
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
for chunk_ind, chunk_str in enumerate(split_texts)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
def _get_metadata_suffix_for_document_index(
|
||||
metadata: dict[str, str | list[str]]
|
||||
) -> str:
|
||||
if not metadata:
|
||||
return ""
|
||||
metadata_str = "Metadata:\n"
|
||||
for key, value in metadata.items():
|
||||
if key in get_metadata_keys_to_ignore():
|
||||
continue
|
||||
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
metadata_str += f"\t{key} - {value_str}\n"
|
||||
return metadata_str.strip()
|
||||
|
||||
|
||||
def chunk_document(
|
||||
document: Document,
|
||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
subsection_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
) -> list[DocAwareChunk]:
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = title.replace("\n", " ") + TITLE_SEPARATOR if title else ""
|
||||
tokenizer = get_default_tokenizer()
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
|
||||
title_tokens = len(tokenizer.tokenize(title_prefix))
|
||||
|
||||
metadata_suffix = ""
|
||||
metadata_tokens = 0
|
||||
if include_metadata:
|
||||
metadata = _get_metadata_suffix_for_document_index(document.metadata)
|
||||
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
|
||||
|
||||
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
|
||||
|
||||
# If there is not enough context remaining then just index the chunk with no prefix/suffix
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
content_token_limit = chunk_tok_size
|
||||
title_prefix = ""
|
||||
metadata_suffix = ""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
for ind, section in enumerate(document.sections):
|
||||
section_text = title_prefix + section.text if ind == 0 else section.text
|
||||
for section in document.sections:
|
||||
section_text = section.text
|
||||
section_link_text = section.link or ""
|
||||
|
||||
section_tok_length = len(tokenizer.tokenize(section_text))
|
||||
@@ -92,16 +145,18 @@ def chunk_document(
|
||||
|
||||
# Large sections are considered self-contained/unique therefore they start a new chunk and are not concatenated
|
||||
# at the end by other sections
|
||||
if section_tok_length > chunk_tok_size:
|
||||
if section_tok_length > content_token_limit:
|
||||
if chunk_text:
|
||||
chunks.append(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
link_offsets = {}
|
||||
@@ -113,9 +168,11 @@ def chunk_document(
|
||||
document=document,
|
||||
start_chunk_id=len(chunks),
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=chunk_tok_size,
|
||||
chunk_size=content_token_limit,
|
||||
chunk_overlap=subsection_overlap,
|
||||
blurb_size=blurb_size,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
chunks.extend(large_section_chunks)
|
||||
continue
|
||||
@@ -125,7 +182,7 @@ def chunk_document(
|
||||
current_tok_length
|
||||
+ len(tokenizer.tokenize(SECTION_SEPARATOR))
|
||||
+ section_tok_length
|
||||
<= chunk_tok_size
|
||||
<= content_token_limit
|
||||
):
|
||||
chunk_text += (
|
||||
SECTION_SEPARATOR + section_text if chunk_text else section_text
|
||||
@@ -137,9 +194,11 @@ def chunk_document(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
link_offsets = {0: section_link_text}
|
||||
@@ -153,9 +212,11 @@ def chunk_document(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=chunk_text,
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
@@ -164,6 +225,9 @@ def chunk_document(
|
||||
def split_chunk_text_into_mini_chunks(
|
||||
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
|
||||
) -> list[str]:
|
||||
"""The minichunks won't all have the title prefix or metadata suffix
|
||||
It could be a significant percentage of every minichunk so better to not include it
|
||||
"""
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
|
||||
@@ -14,12 +14,12 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
|
||||
from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
from danswer.search.enums import EmbedTextType
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.batching import batch_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None = None,
|
||||
provider_type: str | None = None,
|
||||
):
|
||||
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
||||
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
||||
@@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
query_prefix=query_prefix,
|
||||
passage_prefix=passage_prefix,
|
||||
normalize=normalize,
|
||||
api_key=api_key,
|
||||
provider_type=provider_type,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
@@ -81,7 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
for chunk_ind, chunk in enumerate(chunks):
|
||||
chunk_texts.append(chunk.content)
|
||||
mini_chunk_texts = (
|
||||
split_chunk_text_into_mini_chunks(chunk.content)
|
||||
split_chunk_text_into_mini_chunks(chunk.content_summary)
|
||||
if enable_mini_chunk
|
||||
else []
|
||||
)
|
||||
|
||||
@@ -36,6 +36,16 @@ class DocAwareChunk(BaseChunk):
|
||||
# During inference we only have access to the document id and do not reconstruct the Document
|
||||
source_document: Document
|
||||
|
||||
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
|
||||
# it's easier to just store a not prefixed/suffixed string for the highlighting
|
||||
# Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
|
||||
content_summary: str
|
||||
|
||||
# During indexing we also (optionally) build a metadata string from the metadata dict
|
||||
# This is also indexed so that we can strip it out after indexing, this way it supports
|
||||
# multiple iterations of metadata representation for backwards compatibility
|
||||
metadata_suffix: str
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a chunk"""
|
||||
return (
|
||||
@@ -87,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
cloud_provider_id: int | None = None
|
||||
cloud_provider_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
|
||||
def from_model(
|
||||
cls,
|
||||
embedding_model: "EmbeddingModel",
|
||||
) -> "EmbeddingModelDetail":
|
||||
return cls(
|
||||
model_name=embedding_model.model_name,
|
||||
model_dim=embedding_model.model_dim,
|
||||
normalize=embedding_model.normalize,
|
||||
query_prefix=embedding_model.query_prefix,
|
||||
passage_prefix=embedding_model.passage_prefix,
|
||||
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
build_quotes_processor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import message_generator_to_string_generator
|
||||
@@ -43,6 +45,7 @@ from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
|
||||
@@ -58,17 +61,22 @@ from danswer.tools.tool_runner import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_answer_stream_processor(
|
||||
context_docs: list[LlmDoc],
|
||||
search_order_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
answer_style_configs: AnswerStyleConfig,
|
||||
) -> StreamProcessor:
|
||||
if answer_style_configs.citation_config:
|
||||
return build_citation_processor(
|
||||
context_docs=context_docs, search_order_docs=search_order_docs
|
||||
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map
|
||||
)
|
||||
if answer_style_configs.quotes_config:
|
||||
return build_quotes_processor(
|
||||
@@ -81,6 +89,9 @@ def _get_answer_stream_processor(
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -104,6 +115,7 @@ class Answer:
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
return_contexts: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
@@ -132,11 +144,12 @@ class Answer:
|
||||
self._final_prompt: list[BaseMessage] | None = None
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: list[
|
||||
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
|
||||
] | None = None
|
||||
self._processed_stream: (
|
||||
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
|
||||
) = None
|
||||
|
||||
self._return_contexts = return_contexts
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
@@ -228,7 +241,7 @@ class Answer:
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool = [
|
||||
tool for tool in self.tools if tool.name() == tool_call_request["name"]
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
][0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
@@ -247,15 +260,14 @@ class Answer:
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name() == SearchTool.NAME:
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
@@ -281,7 +293,7 @@ class Answer:
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name() == self.force_use_tool.tool_name
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
@@ -301,21 +313,39 @@ class Answer:
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name()}' did not return args")
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
all_tool_args = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
for ind, args in enumerate(all_tool_args):
|
||||
if args is not None:
|
||||
chosen_tool_and_args = (self.tools[ind], args)
|
||||
# for now, just pick the first tool selected
|
||||
break
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.info(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
@@ -336,7 +366,7 @@ class Answer:
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name() == SearchTool.NAME:
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS:
|
||||
@@ -344,12 +374,14 @@ class Answer:
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError("SearchTool did not return final context documents")
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
@@ -371,13 +403,14 @@ class Answer:
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name(),
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
yield final
|
||||
|
||||
prompt = prompt_builder.build()
|
||||
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
|
||||
@@ -417,6 +450,10 @@ class Answer:
|
||||
yield message
|
||||
elif isinstance(message, ToolResponse):
|
||||
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
# We don't need to run section merging in this flow, this variable is only used
|
||||
# below to specify the ordering of the documents for the purpose of matching
|
||||
# citations to the right search documents. The deduplication logic is more lightweight
|
||||
# there and we don't need to do it twice
|
||||
search_results = [
|
||||
llm_doc_from_inference_section(section)
|
||||
for section in cast(
|
||||
@@ -436,20 +473,23 @@ class Answer:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
search_order_docs=search_results or final_context_docs or [],
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
if message:
|
||||
yield cast(str, message)
|
||||
yield from cast(Iterator[str], stream)
|
||||
def _stream() -> Iterator[str]:
|
||||
if message:
|
||||
yield cast(str, message)
|
||||
yield from cast(Iterator[str], stream)
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.chat.models import (
|
||||
LlmDoc,
|
||||
)
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.tools.search.search_utils import llm_doc_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
T = TypeVar("T", bound=LlmDoc | InferenceChunk)
|
||||
|
||||
_METADATA_TOKEN_ESTIMATE = 75
|
||||
|
||||
|
||||
class PruningError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _compute_limit(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
question: str,
|
||||
max_chunks: int | None,
|
||||
max_window_percentage: float | None,
|
||||
max_tokens: int | None,
|
||||
tool_token_count: int,
|
||||
) -> int:
|
||||
llm_max_document_tokens = compute_max_document_tokens(
|
||||
prompt_config=prompt_config,
|
||||
llm_config=llm_config,
|
||||
tool_token_count=tool_token_count,
|
||||
actual_user_input=question,
|
||||
)
|
||||
|
||||
window_percentage_based_limit = (
|
||||
max_window_percentage * llm_max_document_tokens
|
||||
if max_window_percentage
|
||||
else None
|
||||
)
|
||||
chunk_count_based_limit = (
|
||||
max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
|
||||
)
|
||||
|
||||
limit_options = [
|
||||
lim
|
||||
for lim in [
|
||||
window_percentage_based_limit,
|
||||
chunk_count_based_limit,
|
||||
max_tokens,
|
||||
llm_max_document_tokens,
|
||||
]
|
||||
if lim
|
||||
]
|
||||
return int(min(limit_options))
|
||||
|
||||
|
||||
def reorder_docs(
|
||||
docs: list[T],
|
||||
doc_relevance_list: list[bool] | None,
|
||||
) -> list[T]:
|
||||
if doc_relevance_list is None:
|
||||
return docs
|
||||
|
||||
reordered_docs: list[T] = []
|
||||
if doc_relevance_list is not None:
|
||||
for selection_target in [True, False]:
|
||||
for doc, is_relevant in zip(docs, doc_relevance_list):
|
||||
if is_relevant == selection_target:
|
||||
reordered_docs.append(doc)
|
||||
return reordered_docs
|
||||
|
||||
|
||||
def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
|
||||
return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]
|
||||
|
||||
|
||||
def _apply_pruning(
|
||||
docs: list[LlmDoc],
|
||||
doc_relevance_list: list[bool] | None,
|
||||
token_limit: int,
|
||||
is_manually_selected_docs: bool,
|
||||
use_sections: bool,
|
||||
using_tool_message: bool,
|
||||
) -> list[LlmDoc]:
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
docs = deepcopy(docs) # don't modify in place
|
||||
|
||||
# re-order docs with all the "relevant" docs at the front
|
||||
docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
|
||||
# remove docs that are explicitly marked as not for QA
|
||||
docs = _remove_docs_to_ignore(docs=docs)
|
||||
|
||||
tokens_per_doc: list[int] = []
|
||||
final_doc_ind = None
|
||||
total_tokens = 0
|
||||
for ind, llm_doc in enumerate(docs):
|
||||
doc_str = (
|
||||
json.dumps(llm_doc_to_dict(llm_doc, ind))
|
||||
if using_tool_message
|
||||
else build_doc_context_str(
|
||||
semantic_identifier=llm_doc.semantic_identifier,
|
||||
source_type=llm_doc.source_type,
|
||||
content=llm_doc.content,
|
||||
metadata_dict=llm_doc.metadata,
|
||||
updated_at=llm_doc.updated_at,
|
||||
ind=ind,
|
||||
)
|
||||
)
|
||||
|
||||
doc_tokens = len(llm_tokenizer.encode(doc_str))
|
||||
# if chunks, truncate chunks that are way too long
|
||||
# this can happen if the embedding model tokenizer is different
|
||||
# than the LLM tokenizer
|
||||
if (
|
||||
not is_manually_selected_docs
|
||||
and not use_sections
|
||||
and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
|
||||
):
|
||||
logger.warning(
|
||||
"Found more tokens in chunk than expected, "
|
||||
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
|
||||
)
|
||||
llm_doc.content = tokenizer_trim_content(
|
||||
content=llm_doc.content,
|
||||
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
tokens_per_doc.append(doc_tokens)
|
||||
total_tokens += doc_tokens
|
||||
if total_tokens > token_limit:
|
||||
final_doc_ind = ind
|
||||
break
|
||||
|
||||
if final_doc_ind is not None:
|
||||
if is_manually_selected_docs or use_sections:
|
||||
# for document selection, only allow the final document to get truncated
|
||||
# if more than that, then the user message is too long
|
||||
if final_doc_ind != len(docs) - 1:
|
||||
if use_sections:
|
||||
# Truncate the rest of the list since we're over the token limit
|
||||
# for the last one, trim it. In this case, the Sections can be rather long
|
||||
# so better to trim the back than throw away the whole thing.
|
||||
docs = docs[: final_doc_ind + 1]
|
||||
else:
|
||||
raise PruningError(
|
||||
"LLM context window exceeded. Please de-select some documents or shorten your query."
|
||||
)
|
||||
|
||||
amount_to_truncate = total_tokens - token_limit
|
||||
# NOTE: need to recalculate the length here, since the previous calculation included
|
||||
# overhead from JSON-fying the doc / the metadata
|
||||
final_doc_content_length = len(
|
||||
llm_tokenizer.encode(docs[final_doc_ind].content)
|
||||
) - (amount_to_truncate)
|
||||
# this could occur if we only have space for the title / metadata
|
||||
# not ideal, but it's the most reasonable thing to do
|
||||
# NOTE: the frontend prevents documents from being selected if
|
||||
# less than 75 tokens are available to try and avoid this situation
|
||||
# from occurring in the first place
|
||||
if final_doc_content_length <= 0:
|
||||
logger.error(
|
||||
f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
|
||||
"length is less than 0. Removing this doc from the final prompt."
|
||||
)
|
||||
docs.pop()
|
||||
else:
|
||||
docs[final_doc_ind].content = tokenizer_trim_content(
|
||||
content=docs[final_doc_ind].content,
|
||||
desired_length=final_doc_content_length,
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
else:
|
||||
# For regular search, don't truncate the final document unless it's the only one
|
||||
# If it's not the only one, we can throw it away, if it's the only one, we have to truncate
|
||||
if final_doc_ind != 0:
|
||||
docs = docs[:final_doc_ind]
|
||||
else:
|
||||
docs[0].content = tokenizer_trim_content(
|
||||
content=docs[0].content,
|
||||
desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
docs = [docs[0]]
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def prune_documents(
|
||||
docs: list[LlmDoc],
|
||||
doc_relevance_list: list[bool] | None,
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
question: str,
|
||||
document_pruning_config: DocumentPruningConfig,
|
||||
) -> list[LlmDoc]:
|
||||
if doc_relevance_list is not None:
|
||||
assert len(docs) == len(doc_relevance_list)
|
||||
|
||||
doc_token_limit = _compute_limit(
|
||||
prompt_config=prompt_config,
|
||||
llm_config=llm_config,
|
||||
question=question,
|
||||
max_chunks=document_pruning_config.max_chunks,
|
||||
max_window_percentage=document_pruning_config.max_window_percentage,
|
||||
max_tokens=document_pruning_config.max_tokens,
|
||||
tool_token_count=document_pruning_config.tool_num_tokens,
|
||||
)
|
||||
return _apply_pruning(
|
||||
docs=docs,
|
||||
doc_relevance_list=doc_relevance_list,
|
||||
token_limit=doc_token_limit,
|
||||
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
|
||||
use_sections=document_pruning_config.use_sections,
|
||||
using_tool_message=document_pruning_config.using_tool_message,
|
||||
)
|
||||
@@ -70,9 +70,11 @@ class DocumentPruningConfig(BaseModel):
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False
    # If user specifies to include additional context chunks for each match, then different pruning
    # If user specifies to include additional context Chunks for each match, then different pruning
    # is used. As many Sections as possible are included, and the last Section is truncated
    use_sections: bool = False
    # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
    # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize

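# Illustrative sketch (not part of the diff): one way the pruning config above might be
# instantiated for the document-selection flow. Any field not set here is assumed to keep
# its default from the rest of the model.
selection_pruning_config = DocumentPruningConfig(
    is_manually_selected_docs=True,
    use_sections=True,
    tool_num_tokens=0,
)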
360  backend/danswer/llm/answering/prune_and_merge.py  Normal file
@@ -0,0 +1,360 @@
import json
from collections import defaultdict
from copy import deepcopy
from typing import TypeVar

from pydantic import BaseModel

from danswer.chat.models import (
    LlmDoc,
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger


logger = setup_logger()

T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection)

_METADATA_TOKEN_ESTIMATE = 75


class PruningError(Exception):
    pass


class ChunkRange(BaseModel):
    chunks: list[InferenceChunk]
    start: int
    end: int


def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
    """
    This acts on a single document to merge the overlapping ranges of chunks
    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals

    NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
    document index, this does not merge the actual contents so it should not be used to actually
    merge chunks post retrieval.
    """
    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)

    combined_ranges: list[ChunkRange] = []

    for new_chunk_range in sorted_ranges:
        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
            combined_ranges.append(new_chunk_range)
        else:
            current_range = combined_ranges[-1]
            current_range.end = max(current_range.end, new_chunk_range.end)
            current_range.chunks.extend(new_chunk_range.chunks)

    return combined_ranges

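# Illustrative sketch (not part of the diff): how the interval merge above behaves.
# `chunk_a`, `chunk_b` and `chunk_c` are assumed to be InferenceChunk objects from the
# same document; overlapping or adjacent ranges collapse, distant ones stay apart.
ranges = [
    ChunkRange(chunks=[chunk_a], start=0, end=2),
    ChunkRange(chunks=[chunk_b], start=3, end=5),  # adjacent to the first -> merged
    ChunkRange(chunks=[chunk_c], start=9, end=10),  # gap larger than one -> kept separate
]
merged = merge_chunk_intervals(ranges)
# merged[0] now spans start=0, end=5 and holds chunk_a and chunk_b; merged[1] is untouched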
def _compute_limit(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    max_chunks: int | None,
    max_window_percentage: float | None,
    max_tokens: int | None,
    tool_token_count: int,
) -> int:
    llm_max_document_tokens = compute_max_document_tokens(
        prompt_config=prompt_config,
        llm_config=llm_config,
        tool_token_count=tool_token_count,
        actual_user_input=question,
    )

    window_percentage_based_limit = (
        max_window_percentage * llm_max_document_tokens
        if max_window_percentage
        else None
    )
    chunk_count_based_limit = (
        max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
    )

    limit_options = [
        lim
        for lim in [
            window_percentage_based_limit,
            chunk_count_based_limit,
            max_tokens,
            llm_max_document_tokens,
        ]
        if lim
    ]
    return int(min(limit_options))

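# Worked example (not part of the diff) for the limit computation above, under assumed
# values: if compute_max_document_tokens returns 8000, max_window_percentage is 0.5,
# max_chunks is 10 with DOC_EMBEDDING_CONTEXT_SIZE of 512, and max_tokens is None, the
# candidate limits are 4000, 5120 and 8000, so the function returns
# int(min(4000, 5120, 8000)) == 4000. Options that are None (or 0) are dropped from the
# list before taking the minimum.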
def reorder_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
) -> list[InferenceSection]:
    if section_relevance_list is None:
        return sections

    reordered_sections: list[InferenceSection] = []
    if section_relevance_list is not None:
        for selection_target in [True, False]:
            for section, is_relevant in zip(sections, section_relevance_list):
                if is_relevant == selection_target:
                    reordered_sections.append(section)
    return reordered_sections


def _remove_sections_to_ignore(
    sections: list[InferenceSection],
) -> list[InferenceSection]:
    return [
        section
        for section in sections
        if not section.center_chunk.metadata.get(IGNORE_FOR_QA)
    ]


def _apply_pruning(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    token_limit: int,
    is_manually_selected_docs: bool,
    use_sections: bool,
    using_tool_message: bool,
) -> list[InferenceSection]:
    llm_tokenizer = get_default_llm_tokenizer()
    sections = deepcopy(sections)  # don't modify in place

    # re-order docs with all the "relevant" docs at the front
    sections = reorder_sections(
        sections=sections, section_relevance_list=section_relevance_list
    )
    # remove docs that are explicitly marked as not for QA
    sections = _remove_sections_to_ignore(sections=sections)

    final_section_ind = None
    total_tokens = 0
    for ind, section in enumerate(sections):
        section_str = (
            # If using tool message, it will be a bit of an overestimate as the extra json text around the section
            # will be counted towards the token count. However, once the Sections are merged, the extra json parts
            # that overlap will not be counted multiple times like it is in the pruning step.
            json.dumps(section_to_dict(section, ind))
            if using_tool_message
            else build_doc_context_str(
                semantic_identifier=section.center_chunk.semantic_identifier,
                source_type=section.center_chunk.source_type,
                content=section.combined_content,
                metadata_dict=section.center_chunk.metadata,
                updated_at=section.center_chunk.updated_at,
                ind=ind,
            )
        )

        section_tokens = len(llm_tokenizer.encode(section_str))
        # if not using sections (specifically, using Sections where each section maps exactly to the one center chunk),
        # truncate chunks that are way too long. This can happen if the embedding model tokenizer is different
        # than the LLM tokenizer
        if (
            not is_manually_selected_docs
            and not use_sections
            and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
        ):
            logger.warning(
                "Found more tokens in Section than expected, "
                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
            )
            section.combined_content = tokenizer_trim_content(
                content=section.combined_content,
                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                tokenizer=llm_tokenizer,
            )
            section_tokens = DOC_EMBEDDING_CONTEXT_SIZE

        total_tokens += section_tokens
        if total_tokens > token_limit:
            final_section_ind = ind
            break

    if final_section_ind is not None:
        if is_manually_selected_docs or use_sections:
            if final_section_ind != len(sections) - 1:
                # If using Sections, then the final section could be more than we need, in this case we are willing to
                # truncate the final section to fit the specified context window
                sections = sections[: final_section_ind + 1]

                if is_manually_selected_docs:
                    # For document selection flow, only allow the final document/section to get truncated
                    # if more than that needs to be throw away then some documents are completely thrown away in which
                    # case this should be reported to the user as an error
                    raise PruningError(
                        "LLM context window exceeded. Please de-select some documents or shorten your query."
                    )

            amount_to_truncate = total_tokens - token_limit
            # NOTE: need to recalculate the length here, since the previous calculation included
            # overhead from JSON-fying the doc / the metadata
            final_doc_content_length = len(
                llm_tokenizer.encode(sections[final_section_ind].combined_content)
            ) - (amount_to_truncate)
            # this could occur if we only have space for the title / metadata
            # not ideal, but it's the most reasonable thing to do
            # NOTE: the frontend prevents documents from being selected if
            # less than 75 tokens are available to try and avoid this situation
            # from occurring in the first place
            if final_doc_content_length <= 0:
                logger.error(
                    f"Final section ({sections[final_section_ind].center_chunk.semantic_identifier}) content "
                    "length is less than 0. Removing this section from the final prompt."
                )
                sections.pop()
            else:
                sections[final_section_ind].combined_content = tokenizer_trim_content(
                    content=sections[final_section_ind].combined_content,
                    desired_length=final_doc_content_length,
                    tokenizer=llm_tokenizer,
                )
        else:
            # For search on chunk level (Section is just a chunk), don't truncate the final Chunk/Section unless it's the only one
            # If it's not the only one, we can throw it away, if it's the only one, we have to truncate
            if final_section_ind != 0:
                sections = sections[:final_section_ind]
            else:
                sections[0].combined_content = tokenizer_trim_content(
                    content=sections[0].combined_content,
                    desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
                    tokenizer=llm_tokenizer,
                )
                sections = [sections[0]]

    return sections


def prune_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    if section_relevance_list is not None:
        assert len(sections) == len(section_relevance_list)

    token_limit = _compute_limit(
        prompt_config=prompt_config,
        llm_config=llm_config,
        question=question,
        max_chunks=document_pruning_config.max_chunks,
        max_window_percentage=document_pruning_config.max_window_percentage,
        max_tokens=document_pruning_config.max_tokens,
        tool_token_count=document_pruning_config.tool_num_tokens,
    )

    return _apply_pruning(
        sections=sections,
        section_relevance_list=section_relevance_list,
        token_limit=token_limit,
        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
        use_sections=document_pruning_config.use_sections,  # Now default True
        using_tool_message=document_pruning_config.using_tool_message,
    )


def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
    # Assuming there are no duplicates by this point
    sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

    center_chunk = max(
        chunks, key=lambda x: x.score if x.score is not None else float("-inf")
    )

    merged_content = []
    for i, chunk in enumerate(sorted_chunks):
        if i > 0:
            prev_chunk_id = sorted_chunks[i - 1].chunk_id
            if chunk.chunk_id == prev_chunk_id + 1:
                merged_content.append("\n")
            else:
                merged_content.append("\n\n...\n\n")
        merged_content.append(chunk.content)

    combined_content = "".join(merged_content)

    return InferenceSection(
        center_chunk=center_chunk,
        chunks=sorted_chunks,
        combined_content=combined_content,
    )

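# Illustrative sketch (not part of the diff): given chunks from one document with
# chunk_ids 4, 5 and 9 (names are placeholders), the merge above joins the consecutive
# pair with a newline, marks the gap with an ellipsis, and picks the highest-scoring
# chunk as the center chunk:
#
#   section = _merge_doc_chunks(chunks=[chunk_4, chunk_5, chunk_9])
#   section.combined_content == chunk_4.content + "\n" + chunk_5.content
#       + "\n\n...\n\n" + chunk_9.content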
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
    docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
    doc_order: dict[str, int] = {}
    for index, section in enumerate(sections):
        if section.center_chunk.document_id not in doc_order:
            doc_order[section.center_chunk.document_id] = index
        for chunk in [section.center_chunk] + section.chunks:
            chunks_map = docs_map[section.center_chunk.document_id]
            existing_chunk = chunks_map.get(chunk.chunk_id)
            if (
                existing_chunk is None
                or existing_chunk.score is None
                or chunk.score is not None
                and chunk.score > existing_chunk.score
            ):
                chunks_map[chunk.chunk_id] = chunk

    new_sections = []
    for section_chunks in docs_map.values():
        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))

    # Sort by highest score, then by original document order
    # It is now 1 large section per doc, the center chunk being the one with the highest score
    new_sections.sort(
        key=lambda x: (
            x.center_chunk.score if x.center_chunk.score is not None else 0,
            -1 * doc_order[x.center_chunk.document_id],
        ),
        reverse=True,
    )

    return new_sections


def prune_and_merge_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    remaining_sections = prune_sections(
        sections=sections,
        section_relevance_list=section_relevance_list,
        prompt_config=prompt_config,
        llm_config=llm_config,
        question=question,
        document_pruning_config=document_pruning_config,
    )

    merged_sections = _merge_sections(sections=remaining_sections)

    return merged_sections
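# Illustrative sketch (not part of the diff): how the pruning and merging steps compose.
# `retrieved_sections`, the config objects and `user_question` are placeholders for
# whatever the caller already has; sections are assumed to be score-ordered.
#
#   kept = prune_and_merge_sections(
#       sections=retrieved_sections,
#       section_relevance_list=None,  # no relevance-based reordering
#       prompt_config=prompt_config,
#       llm_config=llm_config,
#       question=user_question,
#       document_pruning_config=DocumentPruningConfig(),
#   )
#   # `kept` holds at most one merged InferenceSection per document, trimmed to the token limit.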
||||
@@ -7,7 +7,7 @@ from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -23,104 +23,167 @@ def in_code_block(llm_text: str) -> bool:
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
"""
|
||||
Key aspects:
|
||||
|
||||
1. Stream Processing:
|
||||
- Processes tokens one by one, allowing for real-time handling of large texts.
|
||||
|
||||
2. Citation Detection:
|
||||
- Uses regex to find citations in the format [number].
|
||||
- Example: [1], [2], etc.
|
||||
|
||||
3. Citation Mapping:
|
||||
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
|
||||
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.
|
||||
|
||||
4. Citation Formatting:
|
||||
- Replaces citations with properly formatted versions.
|
||||
- Adds links if available: [[1]](https://example.com)
|
||||
- Handles cases where links are not available: [[1]]()
|
||||
|
||||
5. Duplicate Handling:
|
||||
- Skips consecutive citations of the same document to avoid redundancy.
|
||||
|
||||
6. Output Generation:
|
||||
- Yields DanswerAnswerPiece objects for regular text.
|
||||
- Yields CitationInfo objects for each unique citation encountered.
|
||||
|
||||
7. Context Awareness:
|
||||
- Uses context_docs to access document information for citations.
|
||||
|
||||
This function effectively processes a stream of text, identifies and reformats citations,
|
||||
and provides both the processed text and citation information as output.
|
||||
"""
|
||||
order_mapping = doc_id_to_rank_map.order_mapping
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
citation_order = []
|
||||
curr_segment = ""
|
||||
prepend_bracket = False
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
|
||||
raw_out = ""
|
||||
current_citations: list[int] = []
|
||||
past_cite_count = 0
|
||||
for raw_token in tokens:
|
||||
raw_out += raw_token
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
|
||||
token = next_hold
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
|
||||
# Special case of [1][ where ][ is a single token
|
||||
# This is where the model attempts to do consecutive citations like [1][2]
|
||||
if prepend_bracket:
|
||||
curr_segment += "[" + curr_segment
|
||||
prepend_bracket = False
|
||||
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
|
||||
citations_found = list(re.finditer(citation_pattern, curr_segment))
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
citation_found = re.search(citation_pattern, curr_segment)
|
||||
# `past_cite_count`: number of characters since past citation
|
||||
# 5 to ensure a citation hasn't occured
|
||||
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
|
||||
current_citations = []
|
||||
|
||||
if citation_found and not in_code_block(llm_out):
|
||||
numerical_value = int(citation_found.group(1))
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[
|
||||
numerical_value - 1
|
||||
] # remove 1 index offset
|
||||
if citations_found and not in_code_block(llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
link = context_llm_doc.link
|
||||
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[numerical_value - 1]
|
||||
real_citation_num = order_mapping[context_llm_doc.document_id]
|
||||
|
||||
# Use the citation number for the document's rank in
|
||||
# the search (or selected docs) results
|
||||
curr_segment = re.sub(
|
||||
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
|
||||
)
|
||||
if real_citation_num not in citation_order:
|
||||
citation_order.append(real_citation_num)
|
||||
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
target_citation_num = citation_order.index(real_citation_num) + 1
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
curr_segment = (
|
||||
curr_segment[: length_to_add + start]
|
||||
+ curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
if link:
|
||||
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||
past_cite_count = len(llm_out)
|
||||
current_citations.append(target_citation_num)
|
||||
|
||||
# In case there's another open bracket like [1][, don't want to match this
|
||||
possible_citation_found = None
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||
# citation that needs to be replaced with a link
|
||||
if link:
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
|
||||
else:
|
||||
prev_length = len(curr_segment)
|
||||
curr_segment = (
|
||||
curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(curr_segment) - prev_length
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
|
||||
curr_segment = curr_segment[last_citation_end:]
|
||||
if possible_citation_found:
|
||||
continue
|
||||
|
||||
# Special case with back to back citations [1][2]
|
||||
if curr_segment and curr_segment[-1] == "[":
|
||||
curr_segment = curr_segment[:-1]
|
||||
prepend_bracket = True
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
|
||||
if curr_segment:
|
||||
if prepend_bracket:
|
||||
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
|
||||
else:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
|
||||
|
||||
def build_citation_processor(
|
||||
context_docs: list[LlmDoc], search_order_docs: list[LlmDoc]
|
||||
context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
) -> StreamProcessor:
|
||||
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
|
||||
yield from extract_citations_from_stream(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
doc_id_to_rank_map=map_document_id_order(search_order_docs),
|
||||
doc_id_to_rank_map=doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
return stream_processor
|
||||
|
||||
@@ -1,12 +1,18 @@
from collections.abc import Sequence

from pydantic import BaseModel

from danswer.chat.models import LlmDoc
from danswer.search.models import InferenceChunk


class DocumentIdOrderMapping(BaseModel):
    order_mapping: dict[str, int]


def map_document_id_order(
    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
) -> DocumentIdOrderMapping:
    order_mapping = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
@@ -14,4 +20,4 @@ def map_document_id_order(
        order_mapping[chunk.document_id] = current
        current += 1

    return order_mapping
    return DocumentIdOrderMapping(order_mapping=order_mapping)

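# Illustrative sketch (not part of the diff): with the new return type, callers get a
# small pydantic wrapper instead of a bare dict. Assuming two LlmDocs with distinct
# document_ids:
#
#   mapping = map_document_id_order([doc_a, doc_b])
#   mapping.order_mapping  # e.g. {"doc_a_id": 1, "doc_b_id": 2} (one-indexed by default)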
@@ -23,6 +23,7 @@ from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage

from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
@@ -42,6 +43,8 @@ logger = setup_logger()
litellm.drop_params = True
litellm.telemetry = False

litellm.set_verbose = LOG_ALL_MODEL_INTERACTIONS


def _base_msg_to_role(msg: BaseMessage) -> str:
    if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
@@ -229,32 +232,6 @@ class DefaultMultiLLM(LLM):

        self._model_kwargs = model_kwargs

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        if isinstance(prompt, list):
            for ind, msg in enumerate(prompt):
                if isinstance(msg, AIMessageChunk):
                    if msg.content:
                        log_msg = msg.content
                    elif msg.tool_call_chunks:
                        log_msg = "Tool Calls: " + str(
                            [
                                {
                                    key: value
                                    for key, value in tool_call.items()
                                    if key != "index"
                                }
                                for tool_call in msg.tool_call_chunks
                            ]
                        )
                    else:
                        log_msg = ""
                    logger.debug(f"Message {ind}:\n{log_msg}")
                else:
                    logger.debug(f"Message {ind}:\n{msg.content}")
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def log_model_configs(self) -> None:
        logger.info(f"Config: {self.config}")

@@ -304,17 +281,18 @@ class DefaultMultiLLM(LLM):
            model_name=self._model_version,
            temperature=self._temperature,
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
        )

    def invoke(
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> BaseMessage:
        if LOG_ALL_MODEL_INTERACTIONS:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()
            self._log_prompt(prompt)

        response = cast(
            litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
@@ -323,15 +301,14 @@ class DefaultMultiLLM(LLM):
            response.choices[0].message
        )

    def stream(
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_ALL_MODEL_INTERACTIONS:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()
            self._log_prompt(prompt)

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(prompt)
@@ -357,7 +334,7 @@ class DefaultMultiLLM(LLM):
                "The AI model failed partway through generation, please try again."
            )

        if LOG_ALL_MODEL_INTERACTIONS and output:
        if LOG_DANSWER_MODEL_INTERACTIONS and output:
            content = output.content or ""
            if isinstance(output, AIMessage):
                if content:

@@ -76,7 +76,7 @@ class CustomModelServer(LLM):
    def log_model_configs(self) -> None:
        logger.debug(f"Custom model at: {self._endpoint}")

    def invoke(
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
@@ -84,7 +84,7 @@ class CustomModelServer(LLM):
    ) -> BaseMessage:
        return self._execute(prompt)

    def stream(
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,

@@ -3,9 +3,12 @@ from collections.abc import Iterator
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.utils.logger import setup_logger


@@ -19,6 +22,34 @@ class LLMConfig(BaseModel):
    model_name: str
    temperature: float
    api_key: str | None
    api_base: str | None
    api_version: str | None


def log_prompt(prompt: LanguageModelInput) -> None:
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if isinstance(msg, AIMessageChunk):
                if msg.content:
                    log_msg = msg.content
                elif msg.tool_call_chunks:
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in msg.tool_call_chunks
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Message {ind}:\n{log_msg}")
            else:
                logger.debug(f"Message {ind}:\n{msg.content}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")


class LLM(abc.ABC):
@@ -43,20 +74,48 @@ class LLM(abc.ABC):
    def log_model_configs(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def _precall(self, prompt: LanguageModelInput) -> None:
        if DISABLE_GENERATIVE_AI:
            raise Exception("Generative AI is disabled")
        if LOG_DANSWER_MODEL_INTERACTIONS:
            log_prompt(prompt)

    def invoke(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(prompt, tools, tool_choice)

    @abc.abstractmethod
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

    @abc.abstractmethod
    def stream(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._stream_implementation(prompt, tools, tool_choice)

    @abc.abstractmethod
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError
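# Illustrative sketch (not part of the diff): with invoke()/stream() now concrete on the
# base class, a provider subclass mainly fills in the two implementation hooks. EchoLLM
# and the AIMessage import are assumptions for illustration only, not Danswer code.
#
#   from langchain_core.messages import AIMessage
#
#   class EchoLLM(LLM):
#       def log_model_configs(self) -> None:
#           logger.debug("EchoLLM has no config")
#
#       def _invoke_implementation(self, prompt, tools=None, tool_choice=None):
#           return AIMessage(content=str(prompt))
#
#       def _stream_implementation(self, prompt, tools=None, tool_choice=None):
#           yield AIMessage(content=str(prompt))
#
# Every call still routes through _precall(), so the DISABLE_GENERATIVE_AI kill switch and
# prompt logging apply uniformly to all providers.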
|
||||
@@ -26,6 +26,7 @@ OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"gpt-4",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview",
|
||||
|
||||
@@ -46,6 +46,7 @@ from danswer.db.engine import warm_up_connections
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
@@ -66,11 +67,14 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
|
||||
from danswer.server.features.tool.api import router as tool_router
|
||||
from danswer.server.gpts.api import router as gpts_router
|
||||
from danswer.server.manage.administrative import router as admin_router
|
||||
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from danswer.server.manage.embedding.api import basic_router as embedding_router
|
||||
from danswer.server.manage.get_state import router as state_router
|
||||
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
||||
from danswer.server.manage.llm.api import basic_router as llm_router
|
||||
from danswer.server.manage.secondary_index import router as secondary_index_router
|
||||
from danswer.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from danswer.server.manage.standard_answer import router as standard_answer_router
|
||||
from danswer.server.manage.users import router as user_router
|
||||
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from danswer.server.query_and_chat.chat_backend import router as chat_router
|
||||
@@ -207,6 +211,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.info("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.info("Loading LLM providers from env variables")
|
||||
load_llm_providers(db_session)
|
||||
|
||||
@@ -242,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
time.sleep(wait_time)
|
||||
|
||||
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
yield
|
||||
@@ -273,6 +281,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, slack_bot_management_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, persona_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, prompt_router)
|
||||
@@ -285,6 +294,8 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import LLMRelevanceSummaryResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
@@ -21,6 +22,7 @@ from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
@@ -48,6 +50,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_EVALUATION_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@@ -57,6 +60,7 @@ from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
AnswerObjectIterator = Iterator[
|
||||
@@ -70,6 +74,7 @@ AnswerObjectIterator = Iterator[
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| LLMRelevanceSummaryResponse
|
||||
]
|
||||
|
||||
|
||||
@@ -88,8 +93,9 @@ def stream_answer_objects(
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
@@ -127,6 +133,7 @@ def stream_answer_objects(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
# In chat flow it's given back along with the documents
|
||||
yield QueryRephrase(rephrased_query=rephrased_query)
|
||||
@@ -168,6 +175,7 @@ def stream_answer_objects(
|
||||
max_tokens=max_document_tokens,
|
||||
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
@@ -177,7 +185,11 @@ def stream_answer_objects(
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
bypass_acl=bypass_acl,
|
||||
llm_doc_eval=query_req.llm_doc_eval,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
@@ -185,6 +197,7 @@ def stream_answer_objects(
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
@@ -193,19 +206,23 @@ def stream_answer_objects(
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
tool_name=search_tool.name(),
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
dropped_inds: list[int] = []
|
||||
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
# (likely fine that it comes after the initial creation of the search docs)
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
@@ -238,6 +255,7 @@ def stream_answer_objects(
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
chunk_indices = packet.response
|
||||
|
||||
@@ -249,8 +267,21 @@ def stream_answer_objects(
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SEARCH_EVALUATION_ID:
|
||||
evaluation_response = LLMRelevanceSummaryResponse(
|
||||
relevance_summaries=packet.response
|
||||
)
|
||||
if reference_db_search_docs is not None:
|
||||
update_search_docs_table_with_relevance(
|
||||
db_session=db_session,
|
||||
reference_db_search_docs=reference_db_search_docs,
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
else:
|
||||
yield packet
|
||||
|
||||
@@ -271,7 +302,6 @@ def stream_answer_objects(
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
|
||||
|
||||
@@ -305,8 +335,9 @@ def get_search_answer(
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
|
||||
@@ -27,12 +27,19 @@ class DirectQARequest(ChunkContext):
    messages: list[ThreadMessage]
    prompt_id: int | None
    persona_id: int
    agentic: bool | None = None
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    # This is to forcibly skip (or run) the step, if None it uses the system defaults
    skip_rerank: bool | None = None
    skip_llm_chunk_filter: bool | None = None
    chain_of_thought: bool = False
    return_contexts: bool = False
    # This is to toggle agentic evaluation:
    # 1. Evaluates whether each response is relevant or not
    # 2. Provides a summary of the document's relevance in the results
    llm_doc_eval: bool = False
    # If True, skips generating an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    @root_validator
    def check_chain_of_thought_and_prompt_id(

43  backend/danswer/prompts/agentic_evaluation.py  Normal file
@@ -0,0 +1,43 @@
AGENTIC_SEARCH_SYSTEM_PROMPT = """
You are an expert at evaluating the relevance of a document to a search query.
Provided a document and a search query, you determine if the document is relevant to the user query.
You ALWAYS output the 3 sections described below and every section always begins with the same header line.
The "Chain of Thought" is to help you understand the document and query and their relevance to one another.
The "Useful Analysis" is shown to the user to help them understand why the document is or is not useful for them.
The "Final Relevance Determination" is always a single True or False.

You always output your response following these 3 sections:

1. Chain of Thought:
Provide a chain of thought analysis considering:
- The main purpose and content of the document
- What the user is searching for
- How the document relates to the query
- Potential uses of the document for the given query
Be thorough, but avoid unnecessary repetition. Think step by step.

2. Useful Analysis:
Summarize the contents of the document as it relates to the user query.
BE ABSOLUTELY AS CONCISE AS POSSIBLE.
If the document is not useful, briefly mention what the document is about.
Do NOT say whether this document is useful or not useful, ONLY provide the summary.
If referring to the document, prefer using "this" document over "the" document.

3. Final Relevance Determination:
True or False
"""

AGENTIC_SEARCH_USER_PROMPT = """
Document:
```
{content}
```

Query:
{query}

Be sure to run through the 3 steps of evaluation:
1. Chain of Thought
2. Useful Analysis
3. Final Relevance Determination
""".strip()
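# Illustrative sketch (not part of the diff): one way a caller might split the model
# output produced by the prompt above into its labelled sections. The header strings are
# assumptions taken directly from the prompt text.
#
#   def parse_relevance(raw: str) -> tuple[str, bool]:
#       analysis = raw.split("Useful Analysis:")[-1].split("Final Relevance Determination:")[0].strip()
#       verdict = "true" in raw.split("Final Relevance Determination:")[-1].lower()
#       return analysis, verdict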
@@ -144,6 +144,23 @@ Follow Up Input: {{question}}
Standalone question (Respond with only the short combined query):
""".strip()

INTERNET_SEARCH_QUERY_REPHRASE = f"""
Given the following conversation and a follow up input, rephrase the follow up into a SHORT, \
standalone query suitable for an internet search engine.
IMPORTANT: If a specific query might limit results, keep it broad. \
If a broad query might yield too many results, make it detailed.
If there is a clear change in topic, ensure the query reflects the new topic accurately.
Strip out any information that is not relevant for the internet search.

{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}

Follow Up Input: {{question}}
Internet Search Query (Respond with a detailed and specific query):
""".strip()


# The below prompts are retired
NO_SEARCH = "No Search"

@@ -4,7 +4,7 @@

USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
CHUNK_FILTER_PROMPT = f"""
SECTION_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
@@ -27,4 +27,4 @@ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"

# Use the following for easy viewing of prompts
if __name__ == "__main__":
    print(CHUNK_FILTER_PROMPT)
    print(SECTION_FILTER_PROMPT)

@@ -28,8 +28,3 @@ class SearchType(str, Enum):
class QueryFlow(str, Enum):
    SEARCH = "search"
    QUESTION_ANSWER = "question-answer"


class EmbedTextType(str, Enum):
    QUERY = "query"
    PASSAGE = "passage"

@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel
from pydantic import validator

from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
@@ -47,8 +49,8 @@ class ChunkMetric(BaseModel):
class ChunkContext(BaseModel):
    # Additional surrounding context options, if full doc, then chunks are deduped
    # If surrounding context overlap, it is combined into one
    chunks_above: int = 0
    chunks_below: int = 0
    chunks_above: int = CONTEXT_CHUNKS_ABOVE
    chunks_below: int = CONTEXT_CHUNKS_BELOW
    full_doc: bool = False

    @validator("chunks_above", "chunks_below", pre=True, each_item=False)
@@ -94,7 +96,7 @@ class SearchQuery(ChunkContext):
    # Only used if not skip_rerank
    num_rerank: int | None = NUM_RERANKED_RESULTS
    # Only used if not skip_llm_chunk_filter
    max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
    max_llm_filter_sections: int = NUM_RERANKED_RESULTS

    class Config:
        frozen = True
@@ -128,11 +130,14 @@ class InferenceChunk(BaseChunk):
    recency_bias: float
    score: float | None
    hidden: bool
    is_relevant: bool | None = None
    relevance_explanation: str | None = None
    metadata: dict[str, str | list[str]]
    # Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
    # to specify that a set of words should be highlighted. For example:
    # ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
    match_highlights: list[str]

    # when the doc was last updated
    updated_at: datetime | None
    primary_owners: list[str] | None = None
@@ -162,20 +167,54 @@ class InferenceChunk(BaseChunk):
    def __hash__(self) -> int:
        return hash((self.document_id, self.chunk_id))

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, InferenceChunk):
            return NotImplemented
        if self.score is None:
            if other.score is None:
                return self.chunk_id > other.chunk_id
            return True
        if other.score is None:
            return False
        if self.score == other.score:
            return self.chunk_id > other.chunk_id
        return self.score < other.score

class InferenceSection(InferenceChunk):
    """Section is a combination of chunks. A section could be a single chunk, several consecutive
    chunks or the entire document"""
    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, InferenceChunk):
            return NotImplemented
        if self.score is None:
            return False
        if other.score is None:
            return True
        if self.score == other.score:
            return self.chunk_id < other.chunk_id
        return self.score > other.score
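# Illustrative sketch (not part of the diff): with __lt__/__gt__ defined as above, plain
# sorted() orders chunks by score with ties broken by chunk_id, and chunks whose score is
# None sink to the end of a descending sort.
#
#   ranked = sorted(retrieved_chunks, reverse=True)  # highest-scoring chunk first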
class InferenceChunkUncleaned(InferenceChunk):
    title: str | None  # Separate from Semantic Identifier though often same
    metadata_suffix: str | None

    def to_inference_chunk(self) -> InferenceChunk:
        # Create a dict of all fields except 'title' and 'metadata_suffix'
        # Assumes the cleaning has already been applied and just needs to translate to the right type
        inference_chunk_data = {
            k: v
            for k, v in self.dict().items()
            if k not in ["title", "metadata_suffix"]
        }
        return InferenceChunk(**inference_chunk_data)


class InferenceSection(BaseModel):
    """Section list of chunks with a combined content. A section could be a single chunk, several
    chunks from the same document or the entire document."""

    center_chunk: InferenceChunk
    chunks: list[InferenceChunk]
    combined_content: str

    @classmethod
    def from_chunk(
        cls, inf_chunk: InferenceChunk, content: str | None = None
    ) -> "InferenceSection":
        inf_chunk_data = inf_chunk.dict()
        return cls(**inf_chunk_data, combined_content=content or inf_chunk.content)


class SearchDoc(BaseModel):
    document_id: str
@@ -191,6 +230,8 @@ class SearchDoc(BaseModel):
    hidden: bool
    metadata: dict[str, str | list[str]]
    score: float | None
    is_relevant: bool | None = None
    relevance_explanation: str | None = None
    # Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
    # to specify that a set of words should be highlighted. For example:
    # ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>""]
@@ -199,6 +240,7 @@ class SearchDoc(BaseModel):
    updated_at: datetime | None
    primary_owners: list[str] | None
    secondary_owners: list[str] | None
    is_internet: bool = False

    def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().dict(*args, **kwargs)  # type: ignore
@@ -229,6 +271,13 @@ class SavedSearchDoc(SearchDoc):
        return self.score < other.score


class SavedSearchDocWithContent(SavedSearchDoc):
    """Used for endpoints that need to return the actual contents of the retrieved
    section in addition to the match_highlights."""

    content: str


class RetrievalDocs(BaseModel):
    top_documents: list[SavedSearchDoc]
|
||||
|
||||
@@ -1,15 +1,20 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast

from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.chat.models import RelevanceChunk
from danswer.configs.chat_configs import DISABLE_AGENTIC_SEARCH
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -23,31 +28,14 @@ from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel


class ChunkRange(BaseModel):
chunk: InferenceChunk
start: int
end: int
combined_content: str | None = None


def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
"""This acts on a single document to merge the overlapping ranges of sections
Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
"""
sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)

ans: list[ChunkRange] = []

for chunk_range in sorted_ranges:
if not ans or ans[-1].end < chunk_range.start:
ans.append(chunk_range)
else:
ans[-1].end = max(ans[-1].end, chunk_range.end)

return ans
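Editor's note: the diff moves ChunkRange and merge_chunk_intervals into danswer.llm.answering.prune_and_merge. A standalone sketch of the same merge-intervals behavior on plain index tuples, for illustration only:

```python
# Illustrative only: same idea as the merge-intervals algorithm referenced above,
# operating on plain (start, end) tuples instead of ChunkRange models.
def merge_intervals(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    merged: list[tuple[int, int]] = []
    for start, end in sorted(ranges):
        if not merged or merged[-1][1] < start:
            merged.append((start, end))
        else:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
    return merged

# e.g. chunk-id ranges 3-7, 5-9 and 12-14 collapse to two fetches
assert merge_intervals([(3, 7), (5, 9), (12, 14)]) == [(3, 9), (12, 14)]
```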
logger = setup_logger()


class SearchPipeline:
@@ -59,9 +47,12 @@ class SearchPipeline:
fast_llm: LLM,
db_session: Session,
bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
pruning_config: DocumentPruningConfig | None = None,
):
self.search_request = search_request
self.user = user
@@ -77,61 +68,116 @@ class SearchPipeline:
primary_index_name=self.embedding_model.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config
self.pruning_config: DocumentPruningConfig | None = pruning_config

# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
self._predicted_search_type: SearchType | None = None
self._predicted_flow: QueryFlow | None = None

# Initial document index retrieval chunks
self._retrieved_chunks: list[InferenceChunk] | None = None
# Another call made to the document index to get surrounding sections
self._retrieved_sections: list[InferenceSection] | None = None
self._reranked_chunks: list[InferenceChunk] | None = None
# Reranking and LLM section selection can be run together
# If only LLM selection is on, the reranked chunks are yielded immediately
self._reranked_sections: list[InferenceSection] | None = None
self._relevant_chunk_indices: list[int] | None = None
self._relevant_section_indices: list[int] | None = None

# If chunks have been merged, the LLM filter flow no longer applies
# as the indices no longer match. Can be implemented later as needed
self.ran_merge_chunk = False
# Generates reranked chunks and LLM selections
self._postprocessing_generator: (
Iterator[list[InferenceSection] | list[int]] | None
) = None

# generator state
self._postprocessing_generator: Generator[
list[InferenceChunk] | list[str], None, None
] | None = None
"""Pre-processing"""

def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]:
if not post_rerank and self._retrieved_sections:
def _run_preprocessing(self) -> None:
(
final_search_query,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
llm=self.llm,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
self._search_query = final_search_query
self._predicted_search_type = predicted_search_type
self._predicted_flow = predicted_flow

@property
def search_query(self) -> SearchQuery:
if self._search_query is not None:
return self._search_query

self._run_preprocessing()

return cast(SearchQuery, self._search_query)

@property
def predicted_search_type(self) -> SearchType:
if self._predicted_search_type is not None:
return self._predicted_search_type

self._run_preprocessing()
return cast(SearchType, self._predicted_search_type)

@property
def predicted_flow(self) -> QueryFlow:
if self._predicted_flow is not None:
return self._predicted_flow

self._run_preprocessing()
return cast(QueryFlow, self._predicted_flow)

"""Retrieval and Postprocessing"""

def _get_chunks(self) -> list[InferenceChunk]:
"""TODO as a future extension:
If large chunks (above 512 tokens) are used which cannot be directly fed to the LLM,
this step should run the two retrievals to get all of the base size chunks
"""
if self._retrieved_chunks is not None:
return self._retrieved_chunks

self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
hybrid_alpha=self.search_request.hybrid_alpha,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)

return cast(list[InferenceChunk], self._retrieved_chunks)

def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
that have a corresponding chunk.

This step should be fast for any document index implementation.
"""
if self._retrieved_sections is not None:
return self._retrieved_sections
if post_rerank and self._reranked_sections:
return self._reranked_sections

if not post_rerank:
chunks = self.retrieved_chunks
else:
chunks = self.reranked_chunks
retrieved_chunks = self._get_chunks()

if self._search_query is None:
# Should never happen
raise RuntimeError("Failed in Query Preprocessing")
above = self.search_query.chunks_above
below = self.search_query.chunks_below

functions_with_args: list[tuple[Callable, tuple]] = []
final_inference_sections = []

# Nothing to combine, just return the chunks
if (
not self._search_query.chunks_above
and not self._search_query.chunks_below
and not self._search_query.full_doc
):
return [InferenceSection.from_chunk(chunk) for chunk in chunks]

# If chunk merges have been run, LLM reranking loses meaning
# Needs reimplementation, out of scope for now
self.ran_merge_chunk = True
expanded_inference_sections = []

# Full doc setting takes priority
if self._search_query.full_doc:
if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
for chunk in chunks:
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
seen_document_ids.add(chunk.document_id)
unique_chunks.append(chunk)
@@ -156,43 +202,54 @@ class SearchPipeline:

for ind, chunk in enumerate(unique_chunks):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
final_inference_sections.append(
InferenceSection.from_chunk(chunk, content=combined_content)

inference_section = inference_section_from_chunks(
center_chunk=chunk,
chunks=inf_chunks,
)

return final_inference_sections
if inference_section is not None:
expanded_inference_sections.append(inference_section)
else:
logger.warning("Skipped creation of section, no chunks found")

self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections

# General flow:
# - Combine chunks into lists by document_id
# - For each document, run merge-intervals to get combined ranges
# - This allows for fewer queries to the document index
# - Fetch all of the new chunks with contents for the combined ranges
# - Map it back to the combined ranges (which each know their "center" chunk)
# - Reiterate the chunks again and map to the results above based on the chunk.
# This maintains the original chunks ordering. Note, we cannot simply sort by score here
# as reranking flow may wipe the scores for a lot of the chunks.
doc_chunk_ranges_map = defaultdict(list)
for chunk in chunks:
for chunk in retrieved_chunks:
# The list of ranges for each document is ordered by score
doc_chunk_ranges_map[chunk.document_id].append(
ChunkRange(
chunk=chunk,
start=max(0, chunk.chunk_id - self._search_query.chunks_above),
chunks=[chunk],
start=max(0, chunk.chunk_id - above),
# No max known ahead of time, filter will handle this anyway
end=chunk.chunk_id + self._search_query.chunks_below,
end=chunk.chunk_id + below,
)
)

# List of ranges, outside list represents documents, inner list represents ranges
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
reverse_map = {r.chunk: r for doc_ranges in merged_ranges for r in doc_ranges}
flat_ranges = [r for ranges in merged_ranges for r in ranges]

for chunk_range in reverse_map.values():
for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
chunk_range.chunk.document_id,
# Only need the document_id here, using any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
@@ -206,152 +263,107 @@ class SearchPipeline:
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]

for ind, chunk_range in enumerate(reverse_map.values()):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
chunk_range.combined_content = combined_content
doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk
for chunk in flattened_inference_chunks
}

for chunk in chunks:
if chunk not in reverse_map:
continue
chunk_range = reverse_map[chunk]
final_inference_sections.append(
InferenceSection.from_chunk(
chunk_range.chunk, content=chunk_range.combined_content
)
# Build the surroundings for all of the initial retrieved chunks
for chunk in retrieved_chunks:
start_ind = max(0, chunk.chunk_id - above)
end_ind = chunk.chunk_id + below

# Since the index of the max_chunk is unknown, just allow it to be None and filter after
surrounding_chunks_or_none = [
doc_chunk_ind_to_chunk.get((chunk.document_id, chunk_ind))
for chunk_ind in range(start_ind, end_ind + 1)  # end_ind is inclusive
]
# The None will apply to the would-be "chunks" that are larger than the index of the last chunk
# of the document
surrounding_chunks = [
chunk for chunk in surrounding_chunks_or_none if chunk is not None
]

inference_section = inference_section_from_chunks(
center_chunk=chunk,
chunks=surrounding_chunks,
)
if inference_section is not None:
expanded_inference_sections.append(inference_section)
else:
logger.warning("Skipped creation of section, no chunks found")

return final_inference_sections

"""Pre-processing"""

def _run_preprocessing(self) -> None:
(
final_search_query,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
llm=self.llm,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
self._predicted_search_type = predicted_search_type
self._predicted_flow = predicted_flow
self._search_query = final_search_query

@property
def search_query(self) -> SearchQuery:
if self._search_query is not None:
return self._search_query

self._run_preprocessing()
return cast(SearchQuery, self._search_query)

@property
def predicted_search_type(self) -> SearchType:
if self._predicted_search_type is not None:
return self._predicted_search_type

self._run_preprocessing()
return cast(SearchType, self._predicted_search_type)

@property
def predicted_flow(self) -> QueryFlow:
if self._predicted_flow is not None:
return self._predicted_flow

self._run_preprocessing()
return cast(QueryFlow, self._predicted_flow)

"""Retrieval"""

@property
def retrieved_chunks(self) -> list[InferenceChunk]:
if self._retrieved_chunks is not None:
return self._retrieved_chunks

self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
hybrid_alpha=self.search_request.hybrid_alpha,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)

return cast(list[InferenceChunk], self._retrieved_chunks)

@property
def retrieved_sections(self) -> list[InferenceSection]:
# Calls retrieved_chunks inside
self._retrieved_sections = self._combine_chunks(post_rerank=False)
return self._retrieved_sections

"""Post-Processing"""

@property
def reranked_chunks(self) -> list[InferenceChunk]:
if self._reranked_chunks is not None:
return self._reranked_chunks

self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_chunks,
llm=self.fast_llm,  # use fast_llm for relevance, since it is a relatively easier task
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_chunks = cast(
list[InferenceChunk], next(self._postprocessing_generator)
)
return self._reranked_chunks
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections

@property
def reranked_sections(self) -> list[InferenceSection]:
# Calls reranked_chunks inside
self._reranked_sections = self._combine_chunks(post_rerank=True)
"""Reranking is always done at the chunk level since section merging could create arbitrarily
long sections which could be:
1. Longer than the maximum context limit of even large rerankers
2. Slow to calculate due to the quadratic scaling laws of Transformers

See implementation in search_postprocessing for details
"""
if self._reranked_sections is not None:
return self._reranked_sections

self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_sections=self._get_sections(),
llm=self.fast_llm,
rerank_metrics_callback=self.rerank_metrics_callback,
)

self._reranked_sections = cast(
list[InferenceSection], next(self._postprocessing_generator)
)

return self._reranked_sections

@property
def relevant_chunk_indices(self) -> list[int]:
# If chunks have been merged, then we cannot simply rely on the leading chunk
# relevance, there is no way to get the full relevance of the Section now
# without running a more token heavy pass. This can be an option but not
# implementing now.
if self.ran_merge_chunk:
return []
def relevant_section_indices(self) -> list[int]:
if self._relevant_section_indices is not None:
return self._relevant_section_indices

if self._relevant_chunk_indices is not None:
return self._relevant_chunk_indices

# run first step of postprocessing generator if not already done
reranked_docs = self.reranked_chunks

relevant_chunk_ids = next(
cast(Generator[list[str], None, None], self._postprocessing_generator)
self._relevant_section_indices = next(
cast(Iterator[list[int]], self._postprocessing_generator)
)
self._relevant_chunk_indices = [
ind
for ind, chunk in enumerate(reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
return self._relevant_chunk_indices
return self._relevant_section_indices

@property
def chunk_relevance_list(self) -> list[bool]:
return [
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
def relevance_summaries(self) -> dict[str, RelevanceChunk]:
if DISABLE_AGENTIC_SEARCH:
raise ValueError(
"Agentic search operation called while DISABLE_AGENTIC_SEARCH is toggled"
)
if len(self.reranked_sections) == 0:
logger.warning(
"No sections found in agentic search evaluation. Returning empty dict."
)
return {}

sections = self.reranked_sections
functions = [
FunctionCall(
evaluate_inference_section, (section, self.search_query.query, self.llm)
)
for section in sections
]

results = run_functions_in_parallel(function_calls=functions)

return {
next(iter(value)): value[next(iter(value))] for value in results.values()
}

@property
def section_relevance_list(self) -> list[bool]:
if self.ran_merge_chunk:
return [False] * len(self.reranked_sections)

return [
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
True if ind in self.relevant_section_indices else False
for ind in range(len(self.reranked_sections))
]
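Editor's note: a hedged sketch of how a caller might drive the refactored pipeline shown above; the constructor arguments are assumed to already exist in the caller and are illustrative only.

```python
# Hypothetical usage of the section-level properties introduced in this diff.
pipeline = SearchPipeline(
    search_request=search_request,  # assumed SearchRequest built by the caller
    user=user,
    llm=llm,
    fast_llm=fast_llm,
    db_session=db_session,
)
sections = pipeline.reranked_sections        # first postprocessing yield
relevance = pipeline.section_relevance_list  # booleans aligned with sections
kept_text = [s.combined_content for s, keep in zip(sections, relevance) if keep]
```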
@@ -1,9 +1,11 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast

import numpy

from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.document_index.document_index_utils import (
@@ -12,12 +14,14 @@ from danswer.document_index.document_index_utils import (
from danswer.llm.interfaces import LLM
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
@@ -27,9 +31,12 @@ from danswer.utils.timing import log_function_time
logger = setup_logger()


def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
section.center_chunk.source_links[0]
if section.center_chunk.source_links is not None
else "No Link"
for section in sections
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

@@ -43,6 +50,33 @@ def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter


def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
if not chunk.title or not chunk.content:
return chunk.content

if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()

if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()

return chunk.content

def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
if not chunk.metadata_suffix:
return chunk.content
return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
RETURN_SEPARATOR
)

for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)

return [chunk.to_inference_chunk() for chunk in chunks]
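Editor's note: a plain-string sketch of what cleanup_chunks does to each chunk's content; the title, separator, and metadata suffix below are made-up values for illustration.

```python
# Made-up stand-ins for the chunk title, RETURN_SEPARATOR, and metadata suffix.
title = "Quarterly Report"
separator = "\n\n"
metadata_suffix = "source: finance wiki"
content = f"{title}\nRevenue grew 12% year over year.{separator}{metadata_suffix}"

if content.startswith(title):  # mirrors _remove_title
    content = content[len(title):].lstrip()
# mirrors _remove_metadata_suffix
content = content.removesuffix(metadata_suffix).rstrip(separator)
print(content)  # -> "Revenue grew 12% year over year."
```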
@log_function_time(print_only=True)
def semantic_reranking(
query: str,
@@ -113,84 +147,113 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices)


def rerank_chunks(
def rerank_sections(
query: SearchQuery,
chunks_to_rerank: list[InferenceChunk],
sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
) -> list[InferenceSection]:
"""Chunks are reranked rather than the containing sections because of speed
implications; if reranking models have lower latency for long inputs in the future,
we may rerank on the combined context of the section instead.

Making the assumption here that oftentimes we want larger Sections to provide context
for the LLM to determine if a section is useful but for reranking, we don't need to be
as stringent. If the Section is relevant, we assume that the chunk rerank score will
also be high.
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]

ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]

# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
return ranked_chunks

chunk_id_to_section = {
section.center_chunk.unique_id: section for section in sections_to_rerank
}
ordered_sections = [chunk_id_to_section[chunk.unique_id] for chunk in ranked_chunks]
return ordered_sections


@log_function_time(print_only=True)
def filter_chunks(
def filter_sections(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
sections_to_filter: list[InferenceSection],
llm: LLM,
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
# For cost saving, we may turn this on
use_chunk: bool = False,
) -> list[InferenceSection]:
"""Filters sections based on whether the LLM thought they were relevant to the query.
This applies on the section which has more context than the chunk. Hopefully this yields more accurate LLM evaluations.

Returns a list of the unique chunk IDs that were marked as relevant"""
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
llm_chunk_selection = llm_batch_eval_chunks(
Returns a list of the unique chunk IDs that were marked as relevant
"""
sections_to_filter = sections_to_filter[: query.max_llm_filter_sections]

contents = [
section.center_chunk.content if use_chunk else section.combined_content
for section in sections_to_filter
]

llm_chunk_selection = llm_batch_eval_sections(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
section_contents=contents,
llm=llm,
)

return [
chunk.unique_id
for ind, chunk in enumerate(chunks_to_filter)
section
for ind, section in enumerate(sections_to_filter)
if llm_chunk_selection[ind]
]


def search_postprocessing(
search_query: SearchQuery,
retrieved_chunks: list[InferenceChunk],
retrieved_sections: list[InferenceSection],
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[str], None, None]:
) -> Iterator[list[InferenceSection] | list[int]]:
post_processing_tasks: list[FunctionCall] = []

rerank_task_id = None
chunks_yielded = False
sections_yielded = False
if should_rerank(search_query):
post_processing_tasks.append(
FunctionCall(
rerank_chunks,
rerank_sections,
(
search_query,
retrieved_chunks,
retrieved_sections,
rerank_metrics_callback,
),
)
)
rerank_task_id = post_processing_tasks[-1].result_id
else:
final_chunks = retrieved_chunks
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order
_log_top_chunk_links(search_query.search_type.value, final_chunks)
yield final_chunks
chunks_yielded = True
# since we know this is the final order.
# This way the user experience isn't delayed by the LLM step
_log_top_section_links(search_query.search_type.value, retrieved_sections)
yield retrieved_sections
sections_yielded = True

llm_filter_task_id = None
if should_apply_llm_based_relevance_filter(search_query):
post_processing_tasks.append(
FunctionCall(
filter_chunks,
filter_sections,
(
search_query,
retrieved_chunks[: search_query.max_llm_filter_chunks],
retrieved_sections[: search_query.max_llm_filter_sections],
llm,
),
)
@@ -202,30 +265,30 @@ def search_postprocessing(
if post_processing_tasks
else {}
)
reranked_chunks = cast(
list[InferenceChunk] | None,
reranked_sections = cast(
list[InferenceSection] | None,
post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
)
if reranked_chunks:
if chunks_yielded:
if reranked_sections:
if sections_yielded:
logger.error(
"Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
"Trying to yield re-ranked sections, but sections were already yielded. This should never happen."
)
else:
_log_top_chunk_links(search_query.search_type.value, reranked_chunks)
yield reranked_chunks
_log_top_section_links(search_query.search_type.value, reranked_sections)
yield reranked_sections

llm_chunk_selection = cast(
list[str] | None,
post_processing_results.get(str(llm_filter_task_id))
if llm_filter_task_id
else None,
)
if llm_chunk_selection is not None:
yield [
chunk.unique_id
for chunk in reranked_chunks or retrieved_chunks
if chunk.unique_id in llm_chunk_selection
llm_selected_section_ids = (
[
section.center_chunk.unique_id
for section in post_processing_results.get(str(llm_filter_task_id), [])
]
else:
yield cast(list[str], [])
if llm_filter_task_id
else []
)

yield [
index
for index, section in enumerate(reranked_sections or retrieved_sections)
if section.center_chunk.unique_id in llm_selected_section_ids
]
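Editor's note: a hedged sketch of consuming the two-stage generator defined above; the first value yielded is the (possibly re-ranked) sections, the second is the LLM-selected indices. Inputs are assumed to already exist in the caller.

```python
# Illustrative consumption of search_postprocessing as defined in this diff.
gen = search_postprocessing(
    search_query=search_query,            # assumed SearchQuery from preprocessing
    retrieved_sections=retrieved_sections,
    llm=fast_llm,
)
ranked_sections = next(gen)   # list[InferenceSection]
relevant_indices = next(gen)  # list[int] indexing into ranked_sections
```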
@@ -15,7 +15,7 @@ if TYPE_CHECKING:


def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
"""Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
"""Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
It splits up even foreign characters and unicode emojis without using UNK"""
tokenized_text = tokenizer.tokenize(text)
num_unk_tokens = len(
@@ -73,6 +73,7 @@ def recommend_search_flow(
non_stopword_percent = len(non_stopwords) / len(words)

# UNK tokens -> suggest Keyword (still may be valid QA)
# TODO do a better job with the classifier model and retire the heuristics
if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
if not keyword:
heuristic_search_type = SearchType.KEYWORD

@@ -2,7 +2,6 @@ from sqlalchemy.orm import Session

from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
@@ -36,8 +35,6 @@ def retrieval_preprocessing(
db_session: Session,
bypass_acl: bool = False,
include_query_intent: bool = True,
enable_auto_detect_filters: bool = False,
disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
@@ -63,10 +60,7 @@ def retrieval_preprocessing(

auto_detect_time_filter = True
auto_detect_source_filter = True
if disable_llm_filter_extraction:
auto_detect_time_filter = False
auto_detect_source_filter = False
elif enable_auto_detect_filters is False:
if not search_request.enable_auto_detect_filters:
logger.debug("Retrieval details disables auto detect filters")
auto_detect_time_filter = False
auto_detect_source_filter = False

@@ -7,26 +7,28 @@ from nltk.stem import WordNetLemmatizer  # type:ignore
from nltk.tokenize import word_tokenize  # type:ignore
from sqlalchemy.orm import Session

from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.search.enums import EmbedTextType
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType


logger = setup_logger()
@@ -129,6 +131,8 @@ def doc_index_retrieval(
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
@@ -159,7 +163,7 @@ def doc_index_retrieval(
else:
raise RuntimeError("Invalid Search Flow")

return top_chunks
return cleanup_chunks(top_chunks)


def _simplify_text(text: str) -> str:
@@ -240,30 +244,10 @@ def retrieve_chunks(
return top_chunks


def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
if not inf_chunks:
raise ValueError("Cannot combine empty list of chunks")

# Use the first link of the document
first_chunk = inf_chunks[0]
chunk_texts = [chunk.content for chunk in inf_chunks]
return LlmDoc(
document_id=first_chunk.document_id,
content="\n".join(chunk_texts),
blurb=first_chunk.blurb,
semantic_identifier=first_chunk.semantic_identifier,
source_type=first_chunk.source_type,
metadata=first_chunk.metadata,
updated_at=first_chunk.updated_at,
link=first_chunk.source_links[0] if first_chunk.source_links else None,
source_links=first_chunk.source_links,
)


def inference_documents_from_ids(
def inference_sections_from_ids(
doc_identifiers: list[tuple[str, int]],
document_index: DocumentIndex,
) -> list[LlmDoc]:
) -> list[InferenceSection]:
# Currently only fetches whole docs
doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)

@@ -282,4 +266,17 @@ def inference_documents_from_ids(
# Any failures to retrieve would give a None, drop the Nones and empty lists
inference_chunks_sets = [res for res in parallel_results if res]

return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
return [
inference_section
for inference_section in [
inference_section_from_chunks(
# The scores will always be 0 because the fetching by id gives back
# no search scores. This is not needed though if the user is explicitly
# selecting a document.
center_chunk=chunk_set[0],
chunks=chunk_set,
)
for chunk_set in inference_chunks_sets
]
if inference_section is not None
]

@@ -9,10 +9,10 @@ from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
@@ -40,25 +40,22 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")


# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases
# where this is more important, be sure to refresh with the actual model name
def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer  # type: ignore

global _TOKENIZER
if _TOKENIZER[0] is None or (
_TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
):
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()

if model_name is None:
# This could be inaccurate
model_name = DOCUMENT_ENCODER_MODEL

_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)

if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
@@ -84,20 +81,24 @@ def build_model_server_url(
class EmbeddingModel:
def __init__(
self,
model_name: str,
query_prefix: str | None,
passage_prefix: str | None,
normalize: bool,
server_host: str,  # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
# The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
self.model_name = model_name
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name

model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -111,10 +112,13 @@ class EmbeddingModel:
prefixed_texts = texts

embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)

response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
@@ -177,6 +181,7 @@ def warm_up_encoders(
"https://docs.danswer.dev/quickstart"
)

# May not be the exact same tokenizer used for the indexing flow
get_default_tokenizer(model_name=model_name)(warm_up_str)

embed_model = EmbeddingModel(
@@ -187,6 +192,8 @@ def warm_up_encoders(
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)

# First time downloading the models it may take even longer, but just in case,

@@ -5,10 +5,18 @@ from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SavedSearchDocWithContent
from danswer.search.models import SearchDoc


T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
T = TypeVar(
"T",
InferenceSection,
InferenceChunk,
SearchDoc,
SavedSearchDoc,
SavedSearchDocWithContent,
)


def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
@@ -16,8 +24,13 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
deduped_items = []
dropped_indices = []
for index, item in enumerate(items):
if item.document_id not in seen_ids:
seen_ids.add(item.document_id)
if isinstance(item, InferenceSection):
document_id = item.center_chunk.document_id
else:
document_id = item.document_id

if document_id not in seen_ids:
seen_ids.add(document_id)
deduped_items.append(item)
else:
dropped_indices.append(index)
@@ -37,30 +50,51 @@ def drop_llm_indices(
return [i for i, val in enumerate(llm_bools) if val]


def chunks_or_sections_to_search_docs(
chunks: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
search_docs = (
[
SearchDoc(
document_id=chunk.document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
)
for chunk in chunks
]
if chunks
else []
def inference_section_from_chunks(
center_chunk: InferenceChunk,
chunks: list[InferenceChunk],
) -> InferenceSection | None:
if not chunks:
return None

combined_content = "\n".join([chunk.content for chunk in chunks])

return InferenceSection(
center_chunk=center_chunk,
chunks=chunks,
combined_content=combined_content,
)


def chunks_or_sections_to_search_docs(
items: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
if not items:
return []

search_docs = [
SearchDoc(
document_id=(
chunk := item.center_chunk
if isinstance(item, InferenceSection)
else item
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]

return search_docs
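Editor's note: an illustrative call of the helper above with a mix of chunks and sections, as the new signature allows; the input objects are hypothetical.

```python
# Hypothetical inputs: one InferenceChunk and one InferenceSection from retrieval.
mixed = [some_inference_chunk, some_inference_section]
docs = chunks_or_sections_to_search_docs(mixed)
# Display fields for sections are read off the section's center chunk,
# and every resulting SearchDoc is marked as non-internet.
assert all(doc.is_internet is False for doc in docs)
```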
backend/danswer/secondary_llm_flows/agentic_evaluation.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import re

from danswer.chat.models import RelevanceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_SYSTEM_PROMPT
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_USER_PROMPT
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger

logger = setup_logger()


def _get_agent_eval_messages(
title: str, content: str, query: str
) -> list[dict[str, str]]:
messages = [
{
"role": "system",
"content": AGENTIC_SEARCH_SYSTEM_PROMPT,
},
{
"role": "user",
"content": AGENTIC_SEARCH_USER_PROMPT.format(
title=title, content=content, query=query
),
},
]
return messages


def evaluate_inference_section(
document: InferenceSection, query: str, llm: LLM
) -> dict[str, RelevanceChunk]:
results = {}

document_id = document.center_chunk.document_id
semantic_id = document.center_chunk.semantic_identifier
contents = document.combined_content
chunk_id = document.center_chunk.chunk_id

messages = _get_agent_eval_messages(
title=semantic_id, content=contents, query=query
)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))

# Search for the "Useful Analysis" section in the model output
# This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon,
# then any text up to "3. Final Relevance"
# The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines
# If no match is found, the entire model output is used as the analysis
analysis_match = re.search(
r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance",
model_output,
re.DOTALL,
)
analysis = analysis_match.group(1).strip() if analysis_match else model_output

# Get the last non-empty line
last_line = next(
(line for line in reversed(model_output.split("\n")) if line.strip()), ""
)
relevant = last_line.strip().lower().startswith("true")

results[f"{document_id}-{chunk_id}"] = RelevanceChunk(
relevant=relevant, content=analysis
)
return results
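Editor's note: a small demonstration of what the parsing in evaluate_inference_section extracts; the model response below is invented to match the numbered format the prompt asks for.

```python
import re

# Invented model output shaped like the expected "1. / 2. / 3." response.
model_output = (
    "1. Chain of Thought: the section discusses the 2023 pricing change.\n\n"
    "2. Useful Analysis: it directly answers when the prices changed.\n\n"
    "3. Final Relevance Determination:\nTrue"
)

analysis_match = re.search(
    r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance",
    model_output,
    re.DOTALL,
)
analysis = analysis_match.group(1).strip() if analysis_match else model_output
last_line = next(line for line in reversed(model_output.split("\n")) if line.strip())
relevant = last_line.strip().lower().startswith("true")
print(analysis, relevant)  # -> "it directly answers when the prices changed." True
```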
@@ -3,21 +3,21 @@ from collections.abc import Callable
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
from danswer.prompts.llm_chunk_filter import NONUSEFUL_PAT
from danswer.prompts.llm_chunk_filter import SECTION_FILTER_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool:
def _get_usefulness_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": CHUNK_FILTER_PROMPT.format(
chunk_text=chunk_content, user_query=query
"content": SECTION_FILTER_PROMPT.format(
chunk_text=section_content, user_query=query
),
},
]
@@ -42,13 +42,13 @@ def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
return _extract_usefulness(model_output)


def llm_batch_eval_chunks(
query: str, chunk_contents: list[str], llm: LLM, use_threads: bool = True
def llm_batch_eval_sections(
query: str, section_contents: list[str], llm: LLM, use_threads: bool = True
) -> list[bool]:
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_chunk, (query, chunk_content, llm))
for chunk_content in chunk_contents
(llm_eval_section, (query, section_content, llm))
for section_content in section_contents
]

logger.debug(
@@ -58,11 +58,11 @@ def llm_batch_eval_chunks(
functions_with_args, allow_failures=True
)

# In case of failure/timeout, don't throw out the chunk
# In case of failure/timeout, don't throw out the section
return [True if item is None else item for item in parallel_results]

else:
return [
llm_eval_chunk(query, chunk_content, llm)
for chunk_content in chunk_contents
llm_eval_section(query, section_content, llm)
for section_content in section_contents
]

@@ -74,11 +74,12 @@ def multilingual_query_expansion(
def get_contextual_rephrase_messages(
question: str,
history_str: str,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": HISTORY_QUERY_REPHRASE.format(
"content": prompt_template.format(
question=question, chat_history=history_str
),
},
@@ -94,6 +95,7 @@ def history_based_query_rephrase(
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query
if DISABLE_LLM_QUERY_REPHRASE:
@@ -119,7 +121,7 @@ def history_based_query_rephrase(
)

prompt_msgs = get_contextual_rephrase_messages(
question=query, history_str=history_str
question=query, history_str=history_str, prompt_template=prompt_template
)

filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)

@@ -77,7 +77,7 @@ def associate_credential_to_connector(
connector_id: int,
credential_id: int,
metadata: ConnectorCredentialPairMetadata,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
try:
@@ -97,7 +97,7 @@ def associate_credential_to_connector(
def dissociate_credential_from_connector(
connector_id: int,
credential_id: int,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
return remove_credential_from_connector(

@@ -10,6 +10,7 @@ class ToolSnapshot(BaseModel):
name: str
description: str
definition: dict[str, Any] | None
display_name: str
in_code_tool_id: str | None

@classmethod
@@ -19,5 +20,6 @@ class ToolSnapshot(BaseModel):
name=tool.name,
description=tool.description,
definition=tool.openapi_schema,
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
)

@@ -68,7 +68,7 @@ def gpt_search(
db_session: Session = Depends(get_session),
) -> GptSearchResponse:
llm, fast_llm = get_default_llms()
top_chunks = SearchPipeline(
top_sections = SearchPipeline(
search_request=SearchRequest(
query=search_request.query,
),
@@ -76,20 +76,22 @@ def gpt_search(
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
).reranked_chunks
).reranked_sections

return GptSearchResponse(
matching_document_chunks=[
GptDocChunk(
title=chunk.semantic_identifier,
content=chunk.content,
source_type=chunk.source_type,
link=chunk.source_links.get(0, "") if chunk.source_links else "",
metadata=chunk.metadata,
document_age=time_ago(chunk.updated_at)
if chunk.updated_at
title=section.center_chunk.semantic_identifier,
content=section.center_chunk.content,
source_type=section.center_chunk.source_type,
link=section.center_chunk.source_links.get(0, "")
if section.center_chunk.source_links
else "",
metadata=section.center_chunk.metadata,
document_age=time_ago(section.center_chunk.updated_at)
if section.center_chunk.updated_at
else "Unknown",
)
for chunk in top_chunks
for section in top_sections
],
)
backend/danswer/server/manage/embedding/api.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.embedding_model import get_current_db_embedding_provider
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType

logger = setup_logger()


admin_router = APIRouter(prefix="/admin/embedding")
basic_router = APIRouter(prefix="/embedding")


@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_llm_request: TestEmbeddingRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    try:
        test_model = EmbeddingModel(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
            provider_type=test_llm_request.provider,
            normalize=False,
            query_prefix=None,
            passage_prefix=None,
            model_name=None,
        )
        test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)

    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
        logger.error(error_msg)
        raise ValueError(error_msg)

    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=error_msg)


@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
    return [
        CloudEmbeddingProvider.from_request(embedding_provider_model)
        for embedding_provider_model in fetch_existing_embedding_providers(db_session)
    ]


@admin_router.delete("/embedding-provider/{embedding_provider_name}")
def delete_embedding_provider(
    embedding_provider_name: str,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    embedding_provider = get_current_db_embedding_provider(db_session=db_session)
    if (
        embedding_provider is not None
        and embedding_provider_name == embedding_provider.name
    ):
        raise HTTPException(
            status_code=400, detail="You can't delete a currently active model"
        )

    remove_embedding_provider(db_session, embedding_provider_name)


@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
    provider: CloudEmbeddingProviderCreationRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
    return upsert_cloud_embedding_provider(db_session, provider)
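A rough sketch of how an admin client might exercise the new embedding endpoints above with `requests`. The request and response field names come from the models in this diff; the host/port, any extra prefix the app mounts the router under, the cookie-based admin auth, and the "openai" provider value are all assumptions for illustration.

# Illustrative sketch (not part of the diff)
import requests

BASE = "http://localhost:8080"  # assumed host/port; adjust for your deployment
session = requests.Session()    # assumes an already-authenticated admin session

# Validate a provider key before saving it (TestEmbeddingRequest: provider, api_key)
resp = session.post(
    f"{BASE}/admin/embedding/test-embedding",
    json={"provider": "openai", "api_key": "sk-..."},  # placeholder values
)
resp.raise_for_status()

# Upsert the provider (CloudEmbeddingProviderCreationRequest: name, api_key, default_model_id)
provider = session.put(
    f"{BASE}/admin/embedding/embedding-provider",
    json={"name": "openai", "api_key": "sk-...", "default_model_id": None},
).json()  # returns CloudEmbeddingProvider: id, name, api_key, default_model_id

# List and delete providers
providers = session.get(f"{BASE}/admin/embedding/embedding-provider").json()
session.delete(f"{BASE}/admin/embedding/embedding-provider/openai")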
35  backend/danswer/server/manage/embedding/models.py  Normal file
@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
    from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel


class TestEmbeddingRequest(BaseModel):
    provider: str
    api_key: str | None = None


class CloudEmbeddingProvider(BaseModel):
    name: str
    api_key: str | None = None
    default_model_id: int | None = None
    id: int

    @classmethod
    def from_request(
        cls, cloud_provider_model: "CloudEmbeddingProviderModel"
    ) -> "CloudEmbeddingProvider":
        return cls(
            id=cloud_provider_model.id,
            name=cloud_provider_model.name,
            api_key=cloud_provider_model.api_key,
            default_model_id=cloud_provider_model.default_model_id,
        )


class CloudEmbeddingProviderCreationRequest(BaseModel):
    name: str
    api_key: str | None = None
    default_model_id: int | None = None
@@ -4,6 +4,7 @@ from pydantic import BaseModel

from danswer.llm.llm_provider_options import fetch_models_for_provider


if TYPE_CHECKING:
    from danswer.db.models import LLMProvider as LLMProviderModel
@@ -12,6 +12,8 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
@@ -84,6 +86,57 @@ class HiddenUpdateRequest(BaseModel):
    hidden: bool


class StandardAnswerCategoryCreationRequest(BaseModel):
    name: str


class StandardAnswerCategory(BaseModel):
    id: int
    name: str

    @classmethod
    def from_model(
        cls, standard_answer_category: StandardAnswerCategoryModel
    ) -> "StandardAnswerCategory":
        return cls(
            id=standard_answer_category.id,
            name=standard_answer_category.name,
        )


class StandardAnswer(BaseModel):
    id: int
    keyword: str
    answer: str
    categories: list[StandardAnswerCategory]

    @classmethod
    def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
        return cls(
            id=standard_answer_model.id,
            keyword=standard_answer_model.keyword,
            answer=standard_answer_model.answer,
            categories=[
                StandardAnswerCategory.from_model(standard_answer_category_model)
                for standard_answer_category_model in standard_answer_model.categories
            ],
        )


class StandardAnswerCreationRequest(BaseModel):
    keyword: str
    answer: str
    categories: list[int]

    @validator("categories", pre=True)
    def validate_categories(cls, value: list[int]) -> list[int]:
        if len(value) < 1:
            raise ValueError(
                "At least one category must be attached to a standard answer"
            )
        return value


class SlackBotTokens(BaseModel):
    bot_token: str
    app_token: str
@@ -102,13 +155,14 @@ class SlackBotConfigCreationRequest(BaseModel):
    channel_names: list[str]
    respond_tag_only: bool = False
    respond_to_bots: bool = False
    enable_auto_filters: bool = False
    # If no team members, assume respond in the channel to everyone
    respond_team_member_list: list[str] = []
    respond_slack_group_list: list[str] = []
    respond_member_group_list: list[str] = []
    answer_filters: list[AllowedAnswerFilters] = []
    # list of user emails
    follow_up_tags: list[str] | None = None
    response_type: SlackBotResponseType
    standard_answer_categories: list[int] = []

    @validator("answer_filters", pre=True)
    def validate_filters(cls, value: list[str]) -> list[str]:
@@ -133,6 +187,8 @@ class SlackBotConfig(BaseModel):
    persona: PersonaSnapshot | None
    channel_config: ChannelConfig
    response_type: SlackBotResponseType
    standard_answer_categories: list[StandardAnswerCategory]
    enable_auto_filters: bool

    @classmethod
    def from_model(
@@ -149,6 +205,11 @@ class SlackBotConfig(BaseModel):
            ),
            channel_config=slack_bot_config_model.channel_config,
            response_type=slack_bot_config_model.response_type,
            standard_answer_categories=[
                StandardAnswerCategory.from_model(standard_answer_category_model)
                for standard_answer_category_model in slack_bot_config_model.standard_answer_categories
            ],
            enable_auto_filters=slack_bot_config_model.enable_auto_filters,
        )
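The `categories` validator above rejects any standard answer that is not attached to at least one category. A self-contained sketch of the same pydantic v1-style pattern, mirroring the class and field names from the diff (it does not import the real danswer module, and assumes Python 3.9+ and a pydantic version that still accepts `@validator`):

# Illustrative sketch (not part of the diff)
from pydantic import BaseModel, ValidationError, validator


class StandardAnswerCreationRequest(BaseModel):
    keyword: str
    answer: str
    categories: list[int]

    @validator("categories", pre=True)
    def validate_categories(cls, value: list[int]) -> list[int]:
        if len(value) < 1:
            raise ValueError(
                "At least one category must be attached to a standard answer"
            )
        return value


# Passes: one category id attached
StandardAnswerCreationRequest(keyword="vpn", answer="See the VPN guide.", categories=[1])

# Fails: empty category list is rejected by the validator
try:
    StandardAnswerCreationRequest(keyword="vpn", answer="See the VPN guide.", categories=[])
except ValidationError as e:
    print(e)  # surfaces the "At least one category..." message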
@@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_model_id_from_name
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
@@ -38,6 +39,19 @@ def set_new_embedding_model(
    """
    current_model = get_current_db_embedding_model(db_session)

    if embed_model_details.cloud_provider_name is not None:
        cloud_id = get_model_id_from_name(
            db_session, embed_model_details.cloud_provider_name
        )

        if cloud_id is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No ID exists for given provider name",
            )

        embed_model_details.cloud_provider_id = cloud_id

    if embed_model_details.model_name == current_model.model_name:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
@@ -34,11 +34,8 @@ def _form_channel_config(
) -> ChannelConfig:
    raw_channel_names = slack_bot_config_creation_request.channel_names
    respond_tag_only = slack_bot_config_creation_request.respond_tag_only
    respond_team_member_list = (
        slack_bot_config_creation_request.respond_team_member_list
    )
    respond_slack_group_list = (
        slack_bot_config_creation_request.respond_slack_group_list
    respond_member_group_list = (
        slack_bot_config_creation_request.respond_member_group_list
    )
    answer_filters = slack_bot_config_creation_request.answer_filters
    follow_up_tags = slack_bot_config_creation_request.follow_up_tags
@@ -61,7 +58,7 @@ def _form_channel_config(
            detail=str(e),
        )

    if respond_tag_only and (respond_team_member_list or respond_slack_group_list):
    if respond_tag_only and respond_member_group_list:
        raise ValueError(
            "Cannot set DanswerBot to only respond to tags only and "
            "also respond to a predetermined set of users."
@@ -72,10 +69,8 @@ def _form_channel_config(
    }
    if respond_tag_only is not None:
        channel_config["respond_tag_only"] = respond_tag_only
    if respond_team_member_list:
        channel_config["respond_team_member_list"] = respond_team_member_list
    if respond_slack_group_list:
        channel_config["respond_slack_group_list"] = respond_slack_group_list
    if respond_member_group_list:
        channel_config["respond_member_group_list"] = respond_member_group_list
    if answer_filters:
        channel_config["answer_filters"] = answer_filters
    if follow_up_tags is not None:
@@ -113,7 +108,9 @@ def create_slack_bot_config(
        persona_id=persona_id,
        channel_config=channel_config,
        response_type=slack_bot_config_creation_request.response_type,
        standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
        db_session=db_session,
        enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
    )
    return SlackBotConfig.from_model(slack_bot_config_model)

@@ -171,7 +168,9 @@ def patch_slack_bot_config(
        persona_id=persona_id,
        channel_config=channel_config,
        response_type=slack_bot_config_creation_request.response_type,
        standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
        db_session=db_session,
        enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
    )
    return SlackBotConfig.from_model(slack_bot_config_model)
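Taken together with the model changes earlier in this diff, a Slack bot config creation request now carries the merged `respond_member_group_list` plus the new `standard_answer_categories` and `enable_auto_filters` fields. An illustrative payload sketch; only the field names are taken from `SlackBotConfigCreationRequest` in this diff, while every value (and the "citations" response type) is a placeholder:

# Illustrative sketch (not part of the diff): example request body fields
slack_bot_config_payload = {
    "channel_names": ["#support"],
    "respond_tag_only": False,
    "respond_to_bots": False,
    "enable_auto_filters": True,                         # new in this diff
    "respond_member_group_list": ["user@example.com"],   # replaces the two older lists
    "answer_filters": [],
    "follow_up_tags": None,
    "response_type": "citations",                        # placeholder SlackBotResponseType value
    "standard_answer_categories": [1],                   # new in this diff: category ids
}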
139  backend/danswer/server/manage/standard_answer.py  Normal file
@@ -0,0 +1,139 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.standard_answer import fetch_standard_answer
from danswer.db.standard_answer import fetch_standard_answer_categories
from danswer.db.standard_answer import fetch_standard_answer_category
from danswer.db.standard_answer import fetch_standard_answers
from danswer.db.standard_answer import insert_standard_answer
from danswer.db.standard_answer import insert_standard_answer_category
from danswer.db.standard_answer import remove_standard_answer
from danswer.db.standard_answer import update_standard_answer
from danswer.db.standard_answer import update_standard_answer_category
from danswer.server.manage.models import StandardAnswer
from danswer.server.manage.models import StandardAnswerCategory
from danswer.server.manage.models import StandardAnswerCategoryCreationRequest
from danswer.server.manage.models import StandardAnswerCreationRequest

router = APIRouter(prefix="/manage")


@router.post("/admin/standard-answer")
def create_standard_answer(
    standard_answer_creation_request: StandardAnswerCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswer:
    standard_answer_model = insert_standard_answer(
        keyword=standard_answer_creation_request.keyword,
        answer=standard_answer_creation_request.answer,
        category_ids=standard_answer_creation_request.categories,
        db_session=db_session,
    )
    return StandardAnswer.from_model(standard_answer_model)


@router.get("/admin/standard-answer")
def list_standard_answers(
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> list[StandardAnswer]:
    standard_answer_models = fetch_standard_answers(db_session=db_session)
    return [
        StandardAnswer.from_model(standard_answer_model)
        for standard_answer_model in standard_answer_models
    ]


@router.patch("/admin/standard-answer/{standard_answer_id}")
def patch_standard_answer(
    standard_answer_id: int,
    standard_answer_creation_request: StandardAnswerCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswer:
    existing_standard_answer = fetch_standard_answer(
        standard_answer_id=standard_answer_id,
        db_session=db_session,
    )

    if existing_standard_answer is None:
        raise HTTPException(status_code=404, detail="Standard answer not found")

    standard_answer_model = update_standard_answer(
        standard_answer_id=standard_answer_id,
        keyword=standard_answer_creation_request.keyword,
        answer=standard_answer_creation_request.answer,
        category_ids=standard_answer_creation_request.categories,
        db_session=db_session,
    )
    return StandardAnswer.from_model(standard_answer_model)


@router.delete("/admin/standard-answer/{standard_answer_id}")
def delete_standard_answer(
    standard_answer_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    return remove_standard_answer(
        standard_answer_id=standard_answer_id,
        db_session=db_session,
    )


@router.post("/admin/standard-answer/category")
def create_standard_answer_category(
    standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswerCategory:
    standard_answer_category_model = insert_standard_answer_category(
        category_name=standard_answer_category_creation_request.name,
        db_session=db_session,
    )
    return StandardAnswerCategory.from_model(standard_answer_category_model)


@router.get("/admin/standard-answer/category")
def list_standard_answer_categories(
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> list[StandardAnswerCategory]:
    standard_answer_category_models = fetch_standard_answer_categories(
        db_session=db_session
    )
    return [
        StandardAnswerCategory.from_model(standard_answer_category_model)
        for standard_answer_category_model in standard_answer_category_models
    ]


@router.patch("/admin/standard-answer/category/{standard_answer_category_id}")
def patch_standard_answer_category(
    standard_answer_category_id: int,
    standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> StandardAnswerCategory:
    existing_standard_answer_category = fetch_standard_answer_category(
        standard_answer_category_id=standard_answer_category_id,
        db_session=db_session,
    )

    if existing_standard_answer_category is None:
        raise HTTPException(
            status_code=404, detail="Standard answer category not found"
        )

    standard_answer_category_model = update_standard_answer_category(
        standard_answer_category_id=standard_answer_category_id,
        category_name=standard_answer_category_creation_request.name,
        db_session=db_session,
    )
    return StandardAnswerCategory.from_model(standard_answer_category_model)
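A rough end-to-end sketch against the new router above: create a category, then a standard answer that references it. The paths follow the `/manage` prefix and routes in this file; the host/port, any extra mount prefix, and the cookie-based admin auth are assumptions for illustration.

# Illustrative sketch (not part of the diff)
import requests

BASE = "http://localhost:8080/manage"  # assumed host/port; "/manage" prefix from this file
session = requests.Session()           # assumes an already-authenticated admin session

# 1. Create a category (StandardAnswerCategoryCreationRequest: {"name": ...})
category = session.post(
    f"{BASE}/admin/standard-answer/category", json={"name": "IT"}
).json()

# 2. Create a standard answer attached to it
#    (StandardAnswerCreationRequest: keyword, answer, categories)
answer = session.post(
    f"{BASE}/admin/standard-answer",
    json={
        "keyword": "vpn",
        "answer": "See the VPN setup guide.",
        "categories": [category["id"]],
    },
).json()

# 3. List, update, and delete
session.get(f"{BASE}/admin/standard-answer").json()
session.patch(
    f"{BASE}/admin/standard-answer/{answer['id']}",
    json={"keyword": "vpn", "answer": "Updated answer.", "categories": [category["id"]]},
)
session.delete(f"{BASE}/admin/standard-answer/{answer['id']}")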
Some files were not shown because too many files have changed in this diff.