Compare commits

..

47 Commits

Author SHA1 Message Date
pablodanswer
d97e96b3f0 build org 2024-12-01 17:20:43 -08:00
pablodanswer
911fbfa5a6 k 2024-12-01 17:14:09 -08:00
pablodanswer
d02305671a k 2024-12-01 17:12:54 -08:00
pablodanswer
bdfa29dcb5 slack chat 2024-12-01 15:06:32 -08:00
pablodanswer
897ed03c19 fix memoization 2024-12-01 15:02:46 -08:00
pablodanswer
49f0c4f1f8 fix memoization 2024-12-01 15:02:28 -08:00
pablodanswer
338c02171b rm shs 2024-12-01 12:46:07 -08:00
pablodanswer
ef1ade84b6 k 2024-12-01 12:46:07 -08:00
pablodanswer
7c81566c54 k 2024-12-01 12:46:07 -08:00
pablodanswer
c9df0aea47 k 2024-12-01 12:46:07 -08:00
pablodanswer
92e0aeecba k 2024-12-01 12:46:07 -08:00
pablodanswer
30c7e07783 update for all screen sizes 2024-12-01 12:46:07 -08:00
pablodanswer
e99704e9bd update sidebar line 2024-12-01 12:46:07 -08:00
pablodanswer
7f36387f7f k 2024-12-01 12:46:07 -08:00
pablodanswer
407592445b minor nit 2024-12-01 12:46:07 -08:00
pablodanswer
2e533d8188 minor date range clarity 2024-12-01 12:46:07 -08:00
pablodanswer
5b56869937 quick unification of icons 2024-12-01 12:46:07 -08:00
pablodanswer
7baeab54e2 address comments 2024-12-01 12:46:07 -08:00
pablodanswer
aefcfb75ef k 2024-12-01 12:46:07 -08:00
pablodanswer
e5adcb457d k 2024-12-01 12:46:07 -08:00
pablodanswer
db6463644a small nit 2024-12-01 12:46:07 -08:00
pablodanswer
e26ba70cc6 update filters 2024-12-01 12:46:07 -08:00
pablodanswer
66ff723c94 badge up 2024-12-01 12:46:07 -08:00
pablodanswer
dda66f2178 finalize changes 2024-12-01 12:46:07 -08:00
pablodanswer
0a27f72d20 cleanup complete 2024-12-01 12:46:07 -08:00
pablodanswer
fe397601ed minor cleanup 2024-12-01 12:46:07 -08:00
pablodanswer
3bc187c1d1 clean up unused components 2024-12-01 12:46:07 -08:00
pablodanswer
9a0b9eecf0 source types update 2024-12-01 12:46:07 -08:00
pablodanswer
e08db414c0 viewport height update 2024-12-01 12:46:07 -08:00
pablodanswer
b5734057b7 various updates 2024-12-01 12:46:07 -08:00
pablodanswer
56beb3ec82 k 2024-12-01 12:46:07 -08:00
pablodanswer
9f2c8118d7 updates 2024-12-01 12:46:07 -08:00
pablodanswer
6e4a3d5d57 finalize tags 2024-12-01 12:46:07 -08:00
pablodanswer
5b3dcf718f scroll nit 2024-12-01 12:46:07 -08:00
pablodanswer
07bd20b5b9 push fade 2024-12-01 12:46:07 -08:00
pablodanswer
eb01b175ae update logs 2024-12-01 12:46:06 -08:00
pablodanswer
6f55e5fe56 default 2024-12-01 12:46:06 -08:00
pablodanswer
18e7609bfc update scroll 2024-12-01 12:46:06 -08:00
pablodanswer
dd69ec6cdb cleanup 2024-12-01 12:46:06 -08:00
pablodanswer
e961fa2820 fix mystery reorg 2024-12-01 12:46:06 -08:00
pablodanswer
d41bf9a3ff clean up 2024-12-01 12:46:06 -08:00
pablodanswer
e3a6c76d51 k 2024-12-01 12:46:06 -08:00
pablodanswer
719c2aa0df update 2024-12-01 12:46:06 -08:00
pablodanswer
09f487e402 updates 2024-12-01 12:46:06 -08:00
pablodanswer
33a1548fc1 k 2024-12-01 12:46:06 -08:00
pablodanswer
e87c93226a updated chat flow 2024-12-01 12:46:06 -08:00
pablodanswer
5e11a79593 proper no assistant typing + no assistant modal 2024-12-01 12:46:06 -08:00
223 changed files with 2915 additions and 4134 deletions

View File

@@ -24,8 +24,6 @@ env:
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -73,7 +73,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Literal
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
@@ -8,7 +8,6 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
"""
Determines whether a database object should be included in migrations.

View File

@@ -1,36 +0,0 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -1,6 +1,5 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -18,11 +18,6 @@ class ExternalAccess:
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str

View File

@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -86,7 +87,6 @@ from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -99,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -131,12 +136,11 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC:
return REQUIRE_EMAIL_VERIFICATION
# For other auth types, if the user is authenticated it's assumed that
# the user is already verified via the external IDP
return False
# all other auth types besides basic should require users to be
# verified
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
def verify_email_is_invited(email: str) -> None:

View File

@@ -11,7 +11,6 @@ from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
@@ -333,16 +332,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
return
logger.info("Releasing primary worker lock.")
lock: RedisLock = sender.primary_worker_lock
lock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
def on_setup_logging(

View File

@@ -11,7 +11,6 @@ from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from redis.lock import Lock as RedisLock
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
@@ -39,6 +38,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -116,13 +116,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)
logger.info("Primary worker lock: Acquire starting.")
@@ -231,7 +227,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not hasattr(worker, "primary_worker_lock"):
return
lock: RedisLock = worker.primary_worker_lock
lock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)

View File

@@ -2,55 +2,54 @@ from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
"task": "check_for_indexing",
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
"task": "check_for_pruning",
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"task": "check_for_doc_permissions_sync",
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"task": "check_for_external_group_sync",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -29,7 +28,7 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,

View File

@@ -18,11 +18,9 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -84,7 +82,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
name="check_for_doc_permissions_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -166,7 +164,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -193,7 +191,7 @@ def try_creating_permissions_sync_task(
@shared_task(
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
name="connector_permission_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -263,12 +261,7 @@ def connector_permission_sync_generator_task(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
self.app, lock, document_external_accesses, source_type
)
if tasks_generated is None:
return None
@@ -293,7 +286,7 @@ def connector_permission_sync_generator_task(
@shared_task(
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
name="update_external_document_permissions_task",
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
@@ -304,8 +297,6 @@ def update_external_document_permissions_task(
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
@@ -314,28 +305,18 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Add the users to the DB if they don't exist
# Then we build the update requests to update vespa
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)

View File

@@ -17,7 +17,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -32,14 +31,10 @@ from danswer.redis.redis_connector_ext_group_sync import (
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
logger = setup_logger()
@@ -90,7 +85,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
name="check_for_external_group_sync",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -111,22 +106,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
@@ -182,7 +161,7 @@ def try_creating_external_group_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -212,7 +191,7 @@ def try_creating_external_group_sync_task(
@shared_task(
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
name="connector_external_group_sync_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,

View File

@@ -23,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
@@ -157,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
name="check_for_indexing",
soft_time_limit=300,
bind=True,
)
@@ -487,7 +486,7 @@ def try_creating_indexing_task(
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -525,10 +524,7 @@ def try_creating_indexing_task(
@shared_task(
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
@@ -584,64 +580,39 @@ def connector_indexing_proxy_task(
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
finally:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
job.cancel()
break
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
continue
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -789,12 +760,9 @@ def connector_indexing_task(
)
break
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)

View File

@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,

View File

@@ -8,7 +8,6 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -21,7 +20,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
@@ -77,7 +75,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -186,7 +184,7 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
@@ -211,7 +209,7 @@ def try_creating_prune_generator_task(
@shared_task(
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -240,12 +238,9 @@ def connector_pruning_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)

View File

@@ -9,7 +9,6 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
@shared_task(
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
name="document_by_cc_pair_cleanup_task",
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,

View File

@@ -25,7 +25,6 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
@@ -81,7 +80,7 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -655,28 +654,24 @@ def monitor_ccpair_indexing_taskset(
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
task_state = result.state
result_state = result.state
if (
task_state in READY_STATES
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
task_logger.warning(msg)
@@ -708,7 +703,7 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
@@ -819,7 +814,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
@shared_task(
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,

View File

@@ -2,79 +2,20 @@ import re
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.llm.answering.models import PreviousMessage
from danswer.utils.logger import setup_logger
logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)
return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
match_highlights=inference_section.center_chunk.match_highlights,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
@@ -295,71 +196,3 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -1,30 +1,17 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.context.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
if TYPE_CHECKING:
from danswer.db.models import Prompt
class LlmDoc(BaseModel):
"""This contains the minimal set information for the LLM portion including citations"""
@@ -38,7 +25,6 @@ class LlmDoc(BaseModel):
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
# First chunk of info for streaming QA
@@ -131,6 +117,20 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
content: str
document_id: str
@@ -146,20 +146,14 @@ class DanswerAnswer(BaseModel):
answer: str | None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
contexts: list[DanswerContexts] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
@@ -171,41 +165,9 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
@@ -221,109 +183,3 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)

View File

@@ -6,24 +6,16 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.answer import Answer
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import PromptConfig
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
@@ -62,11 +54,16 @@ from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -105,7 +102,6 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -117,10 +113,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -263,7 +256,6 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| DanswerContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -294,8 +286,6 @@ def stream_chat_message_objects(
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -303,7 +293,6 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
@@ -333,31 +322,17 @@ def stream_chat_message_objects(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
@@ -580,34 +555,19 @@ def stream_chat_message_objects(
reserved_message_id=reserved_message_id,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
include_citations=new_msg_req.persona_override_config.prompts[
0
].include_citations,
)
elif prompt_override:
if not final_msg.prompt:
raise ValueError(
"Prompt override cannot be applied, no base prompt found."
)
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=prompt_override,
)
elif final_msg.prompt:
prompt_config = PromptConfig.from_model(final_msg.prompt)
else:
prompt_config = PromptConfig.from_model(persona.prompts[0])
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
@@ -627,13 +587,11 @@ def stream_chat_message_objects(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
rerank_settings=new_msg_req.rerank_settings,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
bypass_acl=bypass_acl,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
@@ -680,8 +638,7 @@ def stream_chat_message_objects(
reference_db_search_docs = None
qa_docs_response = None
# any files to associate with the AI message e.g. dall-e generated images
ai_message_files = []
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
@@ -736,14 +693,8 @@ def stream_chat_message_objects(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
@@ -769,19 +720,15 @@ def stream_chat_message_objects(
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV
),
)
for file_id in file_ids
]
)
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
@@ -790,8 +737,6 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(DanswerContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
pass
@@ -831,8 +776,7 @@ def stream_chat_message_objects(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
if not answer.is_cancelled():
yield AllCitations(citations=answer.citations)
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
@@ -901,30 +845,3 @@ def stream_chat_message(
)
for obj in objects:
yield get_json_line(obj.model_dump())
@log_function_time()
def gather_stream_for_slack(
packets: ChatPacketStream,
) -> ChatDanswerBotResponse:
response = ChatDanswerBotResponse()
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.docs = packet
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.chat_message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
if answer:
response.answer = answer
return response

View File

@@ -1,62 +0,0 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=content)
raise ValueError(f"New message type {msg.message_type} not handled")
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts

View File

@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -82,7 +85,6 @@ OAUTH_CLIENT_SECRET = (
)
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -116,8 +118,6 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
@@ -308,22 +308,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.
# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# no API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -509,6 +493,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
@@ -522,6 +510,3 @@ API_KEY_HASH_ROUNDS = (
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

View File

@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
DEFAULT_PERSONA_ID = 0
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -261,32 +259,6 @@ class DanswerCeleryPriority(int, Enum):
LOWEST = auto()
class DanswerCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
"update_external_document_permissions_task"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

View File

@@ -4,8 +4,11 @@ import os
# Danswer Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Currently not support chain of thought, probably will add back later
DANSWER_BOT_DISABLE_COT = True
# if set, will default DanswerBot to use quotes and reference documents
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None

View File

@@ -2,8 +2,6 @@ import json
import os
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(

View File

@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll Connector:
- Poll connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -1,11 +1,9 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -15,7 +13,6 @@ from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
from danswer.connectors.confluence.utils import validate_attachment_filetype
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
@@ -72,7 +69,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
@@ -108,8 +104,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -210,14 +204,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
@@ -250,10 +242,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
@@ -277,11 +269,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
@@ -291,7 +281,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=page_perm_sync_data,
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
@@ -301,21 +291,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
if not validate_attachment_filetype(attachment):
continue
attachment_restrictions = attachment.get("restrictions")
if not attachment_restrictions:
attachment_restrictions = page_restrictions
attachment_space_key = attachment.get("space", {}).get("key")
if not attachment_space_key:
attachment_space_key = page_space_key
attachment_perm_sync_data = {
"restrictions": attachment_restrictions or {},
"space_key": attachment_space_key,
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
@@ -323,7 +298,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=attachment_perm_sync_data,
perm_sync_data=perm_sync_data,
)
)
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:

View File

@@ -134,32 +134,6 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -332,13 +306,6 @@ def _validate_connector_configuration(
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "

View File

@@ -32,11 +32,7 @@ def get_user_email_from_username__server(
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
# For now, we'll just return a string that indicates failure
# We may want to revert to returning None in the future
# email = None
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
logger.warning(f"failed to get confluence email for {user_name}")
email = None
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
@@ -177,23 +173,19 @@ def extract_text_from_confluence_html(
return format_document_soup(soup)
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
return attachment["metadata"]["mediaType"] not in [
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
]:
return None
download_link = confluence_client.url + attachment["_links"]["download"]
@@ -249,7 +241,7 @@ def build_confluence_document_id(
return f"{base_url}{content_url}"
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
def extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use

View File

@@ -12,15 +12,12 @@ from dateutil import parser
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
@@ -31,8 +28,6 @@ logger = setup_logger()
SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"
_SLIM_BATCH_SIZE = 1000
def run_graphql_request(
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
class SlabConnector(LoadConnector, PollConnector):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
slab_bot_token: str | None = None,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
self._slab_bot_token: str | None = None
self.slab_bot_token = slab_bot_token
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._slab_bot_token = credentials["slab_bot_token"]
self.slab_bot_token = credentials["slab_bot_token"]
return None
@property
def slab_bot_token(self) -> str:
if self._slab_bot_token is None:
raise ConnectorMissingCredentialError("Slab")
return self._slab_bot_token
def _iterate_posts(
self, time_filter: Callable[[datetime], bool] | None = None
) -> GenerateDocumentsOutput:
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
yield from self._iterate_posts(
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):
slim_doc_batch.append(
SlimDocument(
id=post_id,
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch

View File

@@ -5,11 +5,7 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prune_and_merge import _merge_sections
from danswer.chat.prune_and_merge import ChunkRange
from danswer.chat.prune_and_merge import merge_chunk_intervals
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
@@ -31,6 +27,10 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger

View File

@@ -16,7 +16,7 @@ from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
@@ -40,7 +40,10 @@ from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
_MAX_BLURB_LEN = 45
@@ -324,7 +327,7 @@ def _build_sources_blocks(
def _priority_ordered_documents_blocks(
answer: ChatDanswerBotResponse,
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
@@ -347,7 +350,7 @@ def _priority_ordered_documents_blocks(
def _build_citations_blocks(
answer: ChatDanswerBotResponse,
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
@@ -366,8 +369,51 @@ def _build_citations_blocks(
return citations_block
def _build_quotes_block(
quotes: list[DanswerQuote],
) -> list[Block]:
quote_lines: list[str] = []
doc_to_quotes: dict[str, list[str]] = {}
doc_to_link: dict[str, str] = {}
doc_to_sem_id: dict[str, str] = {}
for q in quotes:
quote = q.quote
doc_id = q.document_id
doc_link = q.link
doc_name = q.semantic_identifier
if doc_link and doc_name and doc_id and quote:
if doc_id not in doc_to_quotes:
doc_to_quotes[doc_id] = [quote]
doc_to_link[doc_id] = doc_link
doc_to_sem_id[doc_id] = (
doc_name
if q.source_type != DocumentSource.SLACK.value
else "#" + doc_name
)
else:
doc_to_quotes[doc_id].append(quote)
for doc_id, quote_strs in doc_to_quotes.items():
quotes_str_clean = [
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
]
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
link = doc_to_link[doc_id]
sem_id = doc_to_sem_id[doc_id]
quote_lines.append(
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
)
if not doc_to_quotes:
return []
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def _build_qa_response_blocks(
answer: ChatDanswerBotResponse,
answer: OneShotQAResponse,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
) -> list[Block]:
retrieval_info = answer.docs
@@ -376,10 +422,13 @@ def _build_qa_response_blocks(
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
quotes = answer.quotes.quotes if answer.quotes else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if (
retrieval_info.applied_time_cutoff
@@ -422,6 +471,16 @@ def _build_qa_response_blocks(
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = _build_quotes_block(quotes)
# if no quotes OR `_build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
)
]
response_blocks: list[Block] = []
@@ -430,6 +489,9 @@ def _build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
@@ -505,9 +567,10 @@ def build_follow_up_resolved_blocks(
def build_slack_response_blocks(
answer: ChatDanswerBotResponse,
tenant_id: str | None,
message_info: SlackMessageInfo,
answer: OneShotQAResponse,
persona: Persona | None,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
@@ -524,6 +587,7 @@ def build_slack_response_blocks(
answer_blocks = _build_qa_response_blocks(
answer=answer,
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
)
@@ -553,7 +617,8 @@ def build_slack_response_blocks(
citations_blocks = []
document_blocks = []
if use_citations and answer.citations:
if use_citations:
# if citations are enabled, only show cited documents
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
@@ -572,5 +637,4 @@ def build_slack_response_blocks(
+ web_follow_up_block
+ follow_up_block
)
return all_blocks

View File

@@ -1,6 +1,7 @@
import functools
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
@@ -8,36 +9,46 @@ from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock
from danswer.chat.chat_utils import prepare_chat_message_request
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.process_message import gather_stream_for_slack
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.db.users import get_user_by_email
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.utils.logger import DanswerLoggingAdapter
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
@@ -72,14 +83,16 @@ def handle_regular_answer(
feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
user = None
@@ -89,18 +102,9 @@ def handle_regular_answer(
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
prompt = None
# If no persona is specified, use the default search based persona
# This way slack flow always has a persona
persona = slack_channel_config.persona if slack_channel_config else None
if not persona:
with get_session_with_tenant(tenant_id) as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
else:
prompt = None
if persona:
document_set_names = [
document_set.name for document_set in persona.document_sets
]
@@ -108,26 +112,6 @@ def handle_regular_answer(
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# TODO: Add in support for Slack to truncate messages based on max LLM context
# llm, _ = get_llms_for_persona(persona)
# llm_tokenizer = get_tokenizer(
# model_name=llm.config.model_name,
# provider_type=llm.config.model_provider,
# )
# # In cases of threads, split the available tokens between docs and thread context
# input_tokens = get_max_input_tokens(
# model_name=llm.config.model_name,
# model_provider=llm.config.model_provider,
# )
# max_history_tokens = int(input_tokens * thread_context_percent)
# combined_message = combine_message_thread(
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
# )
combined_message = slackify_message_thread(messages)
bypass_acl = False
if (
slack_channel_config
@@ -138,6 +122,13 @@ def handle_regular_answer(
# with non-public document sets
bypass_acl = True
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_channel_config is None
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to and not is_bot_msg:
# if the message is not "/danswer" command, then it should have a message ts to respond to
raise RuntimeError(
@@ -150,23 +141,75 @@ def handle_regular_answer(
backoff=2,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_slack_answer(
new_message_request: CreateChatMessageRequest, danswer_user: User | None
) -> ChatDanswerBotResponse:
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with get_session_with_tenant(tenant_id) as db_session:
packets = stream_chat_message_objects(
new_msg_req=new_message_request,
user=danswer_user,
if len(new_message_request.messages) > 1:
if new_message_request.persona_config:
raise RuntimeError("Slack bot does not support persona config")
elif new_message_request.persona_id is not None:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
else:
raise RuntimeError(
"No persona id provided, this should never happen."
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
answer = gather_stream_for_slack(packets)
if answer.error_msg:
raise RuntimeError(answer.error_msg)
return answer
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
@@ -196,24 +239,26 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
# Always apply reranking settings if it exists, this is the non-streaming flow
with get_session_with_tenant(tenant_id) as db_session:
answer_request = prepare_chat_message_request(
message_text=combined_message,
user=user,
persona_id=persona.id,
# This is not used in the Slack flow, only in the answer API
persona_override_config=None,
prompt=prompt,
message_ts_to_respond_to=message_ts_to_respond_to,
retrieval_details=retrieval_details,
rerank_settings=None, # Rerank customization supported in Slack flow
db_session=db_session,
saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
multilingual_query_expansion=saved_search_settings.multilingual_expansion
if saved_search_settings
else None,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
if saved_search_settings
else None,
)
answer = _get_slack_answer(
new_message_request=answer_request, danswer_user=user
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
@@ -314,7 +359,7 @@ def handle_regular_answer(
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{combined_message}' - no documents found"
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
@@ -335,18 +380,18 @@ def handle_regular_answer(
)
return True
only_respond_if_citations = (
only_respond_with_citations_or_quotes = (
channel_conf
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer.citations or answer.quotes)
if (
only_respond_if_citations
and not answer.citations
only_respond_with_citations_or_quotes
and not has_citations_or_quotes
and not message_info.bypass_filters
):
logger.error(
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
@@ -364,8 +409,9 @@ def handle_regular_answer(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
persona=persona,
channel_conf=channel_conf,
use_citations=True, # No longer supporting quotes
use_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)

View File

@@ -1,33 +1,8 @@
from slack_sdk import WebClient
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import MessageType
from danswer.danswerbot.slack.utils import respond_in_thread
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
# Note: this does not handle extremely long threads, every message will be included
# with weaker LLMs, this could cause issues with exceeeding the token limit
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"AI said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)
def send_team_member_message(
client: WebClient,
channel: str,

View File

@@ -19,8 +19,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.chat.models import ThreadMessage
from danswer.configs.app_configs import DEV_MODE
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
from danswer.configs.constants import DanswerRedisLocks
@@ -76,6 +74,7 @@ from danswer.db.slack_bot import fetch_slack_bots
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -251,7 +250,7 @@ class SlackbotHandler:
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired and not DEV_MODE:
if not acquired:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
from danswer.chat.models import ThreadMessage
from danswer.one_shot_answer.models import ThreadMessage
class SlackMessageInfo(BaseModel):

View File

@@ -30,13 +30,13 @@ from danswer.configs.danswerbot_configs import (
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.models import ThreadMessage
from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry

View File

@@ -145,10 +145,16 @@ def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
only_one_shot: bool = False,
limit: int = 50,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
@@ -220,11 +226,12 @@ def delete_messages_and_files_from_chat_session(
def create_chat_session(
db_session: Session,
description: str | None,
description: str,
user_id: UUID | None,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
danswerbot_flow: bool = False,
slack_thread_id: str | None = None,
) -> ChatSession:
@@ -234,6 +241,7 @@ def create_chat_session(
description=description,
llm_override=llm_override,
prompt_override=prompt_override,
one_shot=one_shot,
danswerbot_flow=danswerbot_flow,
slack_thread_id=slack_thread_id,
)
@@ -279,6 +287,8 @@ def duplicate_chat_session_for_user_from_slack(
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack

View File

@@ -37,7 +37,6 @@ from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -427,9 +426,7 @@ def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(
detail="User must authenticate",
)
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()

View File

@@ -1,5 +1,6 @@
import datetime
import json
from enum import Enum as PyEnum
from typing import Any
from typing import Literal
from typing import NotRequired
@@ -963,8 +964,9 @@ class ChatSession(Base):
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
# This chat created by DanswerBot
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
# Only ever set to True if system is set to not hard-delete chats
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -1486,6 +1488,11 @@ class ChannelConfig(TypedDict):
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
class SlackChannelConfig(Base):
__tablename__ = "slack_channel_config"
@@ -1498,6 +1505,9 @@ class SlackChannelConfig(Base):
channel_config: Mapped[ChannelConfig] = mapped_column(
postgresql.JSONB(), nullable=False
)
response_type: Mapped[SlackBotResponseType] = mapped_column(
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False

View File

@@ -415,6 +415,9 @@ def upsert_prompt(
return prompt
# NOTE: This operation cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
def upsert_persona(
user: User | None,
name: str,
@@ -446,12 +449,6 @@ def upsert_persona(
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
"""
NOTE: This operation cannot update persona configuration options that
are core to the persona, such as its display priority and
whether or not the assistant is a built-in / default assistant
"""
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
@@ -489,8 +486,6 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
# Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally.
if persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
@@ -499,9 +494,6 @@ def upsert_persona(
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
)
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
persona.name = name
persona.description = description
persona.num_chunks = num_chunks

View File

@@ -10,6 +10,7 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.models import User
from danswer.db.persona import get_default_prompt
@@ -82,6 +83,7 @@ def insert_slack_channel_config(
slack_bot_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
@@ -113,6 +115,7 @@ def insert_slack_channel_config(
slack_bot_id=slack_bot_id,
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
@@ -127,6 +130,7 @@ def update_slack_channel_config(
slack_channel_config_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
@@ -166,6 +170,7 @@ def update_slack_channel_config(
# will encounter `violates foreign key constraint` errors
slack_channel_config.persona_id = persona_id
slack_channel_config.channel_config = channel_config
slack_channel_config.response_type = response_type
slack_channel_config.standard_answer_categories = list(
existing_standard_answer_categories
)

View File

@@ -4,8 +4,6 @@ schema DANSWER_CHUNK_NAME {
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute
attribute: fast-search
rank: filter
}
field chunk_id type int {
indexing: summary | attribute

View File

@@ -6,7 +6,6 @@ import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
@@ -16,17 +15,13 @@ import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document
from fastapi import UploadFile
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.configs.constants import FileOrigin
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.file_store.file_store import FileStore
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -380,35 +375,3 @@ def extract_file_text(
) from e
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
return ""
def convert_docx_to_txt(
file: UploadFile, file_store: FileStore, file_path: str
) -> None:
file.file.seek(0)
docx_content = file.file.read()
doc = Document(BytesIO(docx_content))
# Extract text from the document
full_text = []
for para in doc.paragraphs:
full_text.append(para.text)
# Join the extracted text
text_content = "\n".join(full_text)
txt_file_path = docx_to_txt_filename(file_path)
file_store.save_file(
file_name=txt_file_path,
content=BytesIO(text_content.encode("utf-8")),
display_name=file.filename,
file_origin=FileOrigin.CONNECTOR,
file_type="text/plain",
)
def docx_to_txt_filename(file_path: str) -> str:
"""
Convert a .docx file path to its corresponding .txt file path.
"""
return file_path.rsplit(".", 1)[0] + ".txt"

View File

@@ -59,12 +59,6 @@ class FileStore(ABC):
Contents of the file and metadata dict
"""
@abstractmethod
def read_file_record(self, file_name: str) -> PGFileStore:
"""
Read the file record by the name
"""
@abstractmethod
def delete_file(self, file_name: str) -> None:
"""

View File

@@ -1,6 +1,6 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4
@@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.b64 import get_image_type
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@@ -75,58 +75,11 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
return unique_id
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id) as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=unique_id,
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return unique_id
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
def save_file(
tenant_id: str,
url: str | None = None,
base64_data: str | None = None,
) -> str:
"""Save a file from either a URL or base64 encoded string.
Args:
tenant_id: The tenant ID to save the file under
url: URL to download file from
base64_data: Base64 encoded file data
Returns:
The unique ID of the saved file
Raises:
ValueError: If neither url nor base64_data is provided, or if both are provided
"""
if url is not None and base64_data is not None:
raise ValueError("Cannot specify both url and base64_data")
if url is not None:
return save_file_from_url(url, tenant_id)
elif base64_data is not None:
return save_file_from_base64(base64_data, tenant_id)
else:
raise ValueError("Must specify either url or base64_data")
def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
# NOTE: be explicit about typing so that if we change things, we get notified
funcs: list[
tuple[
Callable[[str, str | None, str | None], str],
tuple[str, str | None, str | None],
]
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url, tenant_id)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@@ -6,27 +6,33 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from danswer.chat.stream_processing.answer_response_handler import (
from danswer.llm.answering.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
@@ -208,23 +214,18 @@ class Answer:
search_result = SearchTool.get_search_result(current_llm_call) or []
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
)
else:
raise ValueError("No answer style config provided")
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled

View File

@@ -1,22 +1,60 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler

View File

@@ -0,0 +1,163 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class PreviousMessage(BaseModel):
"""Simplified version of `ChatMessage`"""
message: str
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
) -> "PreviousMessage":
message_file_ids = (
[file["id"] for file in chat_message.files] if chat_message.files else []
)
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
files=[
file
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:
content = build_content_with_imgs(self.message, self.files)
if self.message_type == MessageType.USER:
return HumanMessage(content=content)
elif self.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
else:
return SystemMessage(content=content)
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)

View File

@@ -4,26 +4,20 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
def default_build_system_message(
@@ -145,15 +139,3 @@ class AnswerPromptBuilder:
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True

View File

@@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig

View File

@@ -1,10 +1,10 @@
from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK

View File

@@ -0,0 +1,20 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()

View File

@@ -5,16 +5,16 @@ from typing import TypeVar
from pydantic import BaseModel
from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import (
LlmDoc,
)
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content

View File

@@ -3,11 +3,16 @@ from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.llm_response_handler import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -65,29 +70,28 @@ class CitationResponseHandler(AnswerResponseHandler):
yield from self.citation_processor.process_token(content)
# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
# def __init__(
# self,
# context_docs: list[LlmDoc],
# is_json_prompt: bool = True,
# ):
# self.quotes_processor = QuotesProcessor(
# context_docs=context_docs,
# is_json_prompt=is_json_prompt,
# )
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
# def handle_response_part(
# self,
# response_item: BaseMessage | None,
# previous_response_items: list[BaseMessage],
# ) -> Generator[ResponsePart, None, None]:
# if response_item is None:
# yield from self.quotes_processor.process_token(None)
# return
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
# content = (
# response_item.content if isinstance(response_item.content, str) else ""
# )
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# yield from self.quotes_processor.process_token(content)
yield from self.quotes_processor.process_token(content)

View File

@@ -4,8 +4,8 @@ from collections.abc import Generator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -67,9 +67,9 @@ class CitationProcessor:
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
@@ -77,15 +77,13 @@ class CitationProcessor:
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
result = ""
result = "" # Initialize result here
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
numerical_value = int(citation.group(1))
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
@@ -133,6 +131,14 @@ class CitationProcessor:
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
@@ -143,7 +149,6 @@ class CitationProcessor:
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (

View File

@@ -1,4 +1,3 @@
# THIS IS NO LONGER IN USE
import math
import re
from collections.abc import Generator
@@ -6,10 +5,11 @@ from json import JSONDecodeError
from typing import Optional
import regex
from pydantic import BaseModel
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.context.search.models import InferenceChunk
@@ -26,20 +26,6 @@ logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
def _extract_answer_quotes_freeform(
answer_raw: str,
) -> tuple[Optional[str], Optional[list[str]]]:

View File

@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.chat.models import ResponsePart
from danswer.chat.prompt_builder.build import LLMCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message

View File

@@ -1,6 +1,5 @@
from typing import Any
from danswer.chat.models import PersonaOverrideConfig
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -14,11 +13,8 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
logger = setup_logger()
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
@@ -36,15 +32,11 @@ def get_main_llm_from_tuple(
def get_llms_for_persona(
persona: Persona | PersonaOverrideConfig | None,
persona: Persona,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
if persona is None:
logger.warning("No persona provided, using default LLMs")
return get_default_llms()
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
@@ -79,7 +71,6 @@ def get_llms_for_persona(
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
temperature=temperature_override,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
@@ -137,13 +128,11 @@ def get_llm(
api_base: str | None = None,
api_version: str | None = None,
custom_config: dict[str, str] | None = None,
temperature: float | None = None,
temperature: float = GEN_AI_TEMPERATURE,
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
if temperature is None:
temperature = GEN_AI_TEMPERATURE
return DefaultMultiLLM(
model_provider=provider,
model_name=model,

View File

@@ -1,59 +0,0 @@
from typing import TYPE_CHECKING
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
class PreviousMessage(BaseModel):
"""Simplified version of `ChatMessage`"""
message: str
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
) -> "PreviousMessage":
message_file_ids = (
[file["id"] for file in chat_message.files] if chat_message.files else []
)
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
files=[
file
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:
content = build_content_with_imgs(self.message, self.files)
if self.message_type == MessageType.USER:
return HumanMessage(content=content)
elif self.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
else:
return SystemMessage(content=content)

View File

@@ -5,6 +5,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
import litellm # type: ignore
import pandas as pd
@@ -34,15 +36,17 @@ from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from danswer.db.models import ChatMessage
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.utils.b64 import get_image_type
from danswer.utils.b64 import get_image_type_from_bytes
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
if TYPE_CHECKING:
from danswer.llm.answering.models import PreviousMessage
logger = setup_logger()
@@ -100,6 +104,39 @@ def litellm_exception_to_error_msg(
return error_msg
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=content)
raise ValueError(f"New message type {msg.message_type} not handled")
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
@@ -153,7 +190,6 @@ def build_content_with_imgs(
message: str,
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
@@ -166,7 +202,6 @@ def build_content_with_imgs(
)
img_urls = img_urls or []
b64_imgs = b64_imgs or []
message_main_content = _build_content(message, files)
@@ -185,22 +220,11 @@ def build_content_with_imgs(
{
"type": "image_url",
"image_url": {
"url": (
f"data:{get_image_type_from_bytes(file.content)};"
f"base64,{file.to_base64()}"
),
"url": f"data:image/jpeg;base64,{file.to_base64()}",
},
}
for file in img_files
]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
},
}
for b64_img in b64_imgs
for file in files
if file.file_type == "image"
]
+ [
{

View File

@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import create_danswer_oauth_router
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
@@ -91,7 +92,6 @@ from danswer.server.settings.api import basic_router as settings_router
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.server.utils import BasicAuthenticationError
from danswer.setup import setup_danswer
from danswer.setup import setup_multitenant_danswer
from danswer.utils.logger import setup_logger
@@ -206,7 +206,7 @@ def log_http_error(_: Request, exc: Exception) -> JSONResponse:
if isinstance(exc, BasicAuthenticationError):
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
logger.warning(f"Authentication failed: {str(exc)}")
logger.error(f"Authentication failed: {str(exc)}")
elif status_code >= 400:
error_msg = f"{str(exc)}\n"

View File

@@ -0,0 +1,456 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import reorganize_citations
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import DocumentRelevance
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import RelevanceAnalysis
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.one_shot_answer.qa_utils import slackify_message_thread
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_generator_function_time
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
AnswerObjectIterator = Iterator[
QueryRephrase
| QADocsResponse
| LLMRelevanceFilterResponse
| DanswerAnswerPiece
| DanswerQuotes
| DanswerContexts
| StreamingError
| ChatMessageDetail
| CitationInfo
| ToolCallKickoff
| DocumentRelevance
]
def stream_answer_objects(
query_req: DirectQARequest,
user: User | None,
# These need to be passed in because in Web UI one shot flow,
# we can have much more document as there is no history.
# For Slack flow, we need to save more tokens for the thread context
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
timeout: int = QA_TIMEOUT,
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
"""Streams in order:
1. [always] Retrieved documents, stops flow if nothing is found
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end
or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
user_id = user.id if user is not None else None
query_msg = query_req.messages[-1]
history = query_req.messages[:-1]
chat_session = create_chat_session(
db_session=db_session,
description="", # One shot queries don't need naming as it's never displayed
user_id=user_id,
persona_id=query_req.persona_id,
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
temporary_persona = fetch_ee_implementation_or_noop(
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
persona = temporary_persona if temporary_persona else chat_session.persona
try:
llm, fast_llm = get_llms_for_persona(
persona=persona, long_term_logger=long_term_logger
)
except ValueError as e:
logger.error(
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
)
if "No LLM provider" in str(e):
raise ValueError(
"Please configure a Generative AI model to use this feature."
) from e
raise ValueError(
"Failed to initialize the AI model. Please check your configuration and try again."
) from e
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Create a chat session which will just store the root message, the query, and the AI response
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
history_str = combine_message_thread(
messages=history,
max_tokens=max_history_tokens,
llm_tokenizer=llm_tokenizer,
)
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
)
# Given back ahead of the documents for latency reasons
# In chat flow it's given back along with the documents
yield QueryRephrase(rephrased_query=rephrased_query)
prompt = None
if query_req.prompt_id is not None:
# NOTE: let the user access any prompt as long as the Persona is shared
# with them
prompt = get_prompt_by_id(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = persona.prompts[0]
user_message_str = query_msg.message
# For this endpoint, we only save one user message to the chat session
# However, for slackbot, we want to include the history of the entire thread
if danswerbot_flow:
# Right now, we only support bringing over citations and search docs
# from the last message in the thread, not the entire thread
# in the future, we may want to retrieve the entire thread
user_message_str = slackify_message_thread(query_req.messages)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=user_message_str,
token_count=len(llm_tokenizer.encode(user_message_str)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
search_tool = SearchTool(
db_session=db_session,
user=user,
evaluation_type=(
LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type
),
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
)
answer = Answer(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(
get_llms_for_persona(persona=persona, long_term_logger=long_term_logger)
),
single_message_history=history_str,
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any FileChatDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
# (likely fine that it comes after the initial creation of the search docs)
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, packet.response)
top_docs = chunks_or_sections_to_search_docs(
search_response_summary.top_sections
)
# Deduping happens at the last step to avoid harming quality by dropping content early on
deduped_docs = top_docs
if query_req.retrieval_options.dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
initial_response = QADocsResponse(
rephrased_query=rephrased_query,
top_documents=response_docs,
predicted_flow=search_response_summary.predicted_flow,
predicted_search=search_response_summary.predicted_search,
applied_source_filters=search_response_summary.final_filters.source_type,
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
document_based_response = {}
if packet.response is not None:
for evaluation in packet.response:
document_based_response[
evaluation.document_id
] = RelevanceAnalysis(
relevant=evaluation.relevant, content=evaluation.content
)
evaluation_response = DocumentRelevance(
relevance_summaries=document_based_response
)
if reference_db_search_docs is not None:
update_search_docs_table_with_relevance(
db_session=db_session,
reference_db_search_docs=reference_db_search_docs,
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet
# Saving Gen AI answer and responding with message info
gen_ai_response_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=query_req.prompt_id,
message=answer.llm_answer,
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
message_type=MessageType.ASSISTANT,
error=None,
reference_docs=reference_db_search_docs,
db_session=db_session,
commit=True,
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
@log_generator_function_time()
def stream_search_answer(
query_req: DirectQARequest,
user: User | None,
max_document_tokens: int | None,
max_history_tokens: int | None,
) -> Iterator[str]:
with get_session_context_manager() as session:
objects = stream_answer_objects(
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=session,
)
for obj in objects:
yield get_json_line(obj.model_dump())
def get_search_answer(
query_req: DirectQARequest,
user: User | None,
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
answer_generation_timeout: int = QA_TIMEOUT,
enable_reflexion: bool = False,
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> OneShotQAResponse:
"""Collects the streamed one shot answer responses into a single object"""
qa_response = OneShotQAResponse()
results = stream_answer_objects(
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=danswerbot_flow,
timeout=answer_generation_timeout,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
answer = ""
for packet in results:
if isinstance(packet, QueryRephrase):
qa_response.rephrase = packet.rephrased_query
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
qa_response.docs = packet
elif isinstance(packet, LLMRelevanceFilterResponse):
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, DanswerQuotes):
qa_response.quotes = packet
elif isinstance(packet, CitationInfo):
if qa_response.citations:
qa_response.citations.append(packet)
else:
qa_response.citations = [packet]
elif isinstance(packet, DanswerContexts):
qa_response.contexts = packet
elif isinstance(packet, StreamingError):
qa_response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
qa_response.chat_message_id = packet.message_id
if answer:
qa_response.answer = answer
if enable_reflexion:
# Because follow up messages are explicitly tagged, we don't need to verify the answer
if len(query_req.messages) == 1:
first_query = query_req.messages[0].message
qa_response.answer_valid = get_answer_validity(first_query, answer)
else:
qa_response.answer_valid = True
if use_citations and qa_response.answer and qa_response.citations:
# Reorganize citation nums to be in the same order as the answer
qa_response.answer, qa_response.citations = reorganize_citations(
qa_response.answer, qa_response.citations
)
return qa_response

View File

@@ -0,0 +1,114 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
class QueryRephrase(BaseModel):
rephrased_query: str
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class PromptConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class ToolConfig(BaseModel):
id: int
class PersonaConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
class DirectQARequest(ChunkContext):
persona_config: PersonaConfig | None = None
persona_id: int | None = None
messages: list[ThreadMessage]
prompt_id: int | None = None
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
chain_of_thought: bool = False
return_contexts: bool = False
# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "DirectQARequest":
if (self.persona_config is None) == (self.persona_id is None):
raise ValueError("Exactly one of persona_config or persona_id must be set")
return self
@model_validator(mode="after")
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
if self.chain_of_thought and self.prompt_id is not None:
raise ValueError(
"If chain_of_thought is True, prompt_id must be None"
"The chain of thought prompt is only for question "
"answering and does not accept customizing."
)
return self
class OneShotQAResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
rephrase: str | None = None
quotes: DanswerQuotes | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
chat_message_id: int | None = None
contexts: DanswerContexts | None = None

View File

@@ -0,0 +1,81 @@
from collections.abc import Generator
from danswer.configs.constants import MessageType
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger
logger = setup_logger()
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:
yield token
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def slackify_message(message: ThreadMessage) -> str:
if message.role != MessageType.USER:
return message.message
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"DanswerBot said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)

View File

@@ -5,11 +5,11 @@ from typing import cast
from langchain_core.messages import BaseMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.constants import DocumentSource
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Prompt
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT

View File

@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
@@ -106,7 +105,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.models import Document as DbDocument
@@ -115,7 +114,7 @@ class RedisConnectorDelete:
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,

View File

@@ -12,7 +12,6 @@ from danswer.access.models import DocExternalAccess
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
class RedisConnectorPermissionSyncPayload(BaseModel):
@@ -133,8 +132,6 @@ class RedisConnectorPermissionSync:
lock: RedisLock | None,
new_permissions: list[DocExternalAccess],
source_string: str,
connector_id: int,
credential_id: int,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
@@ -152,13 +149,11 @@ class RedisConnectorPermissionSync:
self.redis.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
"update_external_document_permissions_task",
kwargs=dict(
tenant_id=self.tenant_id,
serialized_doc_external_access=doc_perm.to_dict(),
source_string=source_string,
connector_id=connector_id,
credential_id=credential_id,
),
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,

View File

@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -135,7 +134,7 @@ class RedisConnectorPrune:
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document_set import construct_document_select_by_docset
from danswer.redis.redis_object_helper import RedisObjectHelper
@@ -77,7 +76,7 @@ class RedisDocumentSet(RedisObjectHelper):
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.redis.redis_object_helper import RedisObjectHelper
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
@@ -90,7 +89,7 @@ class RedisUserGroup(RedisObjectHelper):
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,

View File

@@ -3,14 +3,14 @@ from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
from danswer.prompts.chat_prompts import NO_SEARCH
from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT

View File

@@ -4,10 +4,10 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE

View File

@@ -20,7 +20,6 @@ from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.connectors.google_utils.google_auth import (
@@ -86,7 +85,6 @@ from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_processing.extract_file_text import convert_docx_to_txt
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
@@ -394,12 +392,6 @@ def upload_files(
file_origin=FileOrigin.CONNECTOR,
file_type=file.content_type or "text/plain",
)
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
convert_docx_to_txt(file, file_store, file_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return FileUploadResponse(file_paths=deduped_file_paths)
@@ -875,7 +867,7 @@ def connector_run_once(
# run the beat task to pick up the triggers immediately
primary_app.send_task(
DanswerCeleryTask.CHECK_FOR_INDEXING,
"check_for_indexing",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
@@ -1017,18 +1009,37 @@ def get_connector_by_id(
class BasicCCPairInfo(BaseModel):
docs_indexed: int
has_successful_run: bool
source: DocumentSource
@router.get("/connector-status")
@router.get("/indexing-status")
def get_basic_connector_indexing_status(
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
return [
BasicCCPairInfo(
docs_indexed=cc_pair_to_document_cnt.get(
(cc_pair.connector_id, cc_pair.credential_id)
)
or 0,
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
)

View File

@@ -13,7 +13,6 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.chat.prompt_builder.utils import build_dummy_prompt
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
@@ -34,6 +33,7 @@ from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus
from danswer.server.features.persona.models import PersonaCategoryCreate

View File

@@ -13,7 +13,6 @@ from danswer.auth.users import current_curator_or_admin_user
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -200,7 +199,7 @@ def create_deletion_attempt_for_connector_id(
# run the beat task to pick up this deletion from the db immediately
primary_app.send_task(
DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -16,6 +15,7 @@ from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBot as SlackAppModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel
from danswer.db.models import User
from danswer.server.features.persona.models import PersonaSnapshot
@@ -148,12 +148,6 @@ class SlackBotTokens(BaseModel):
model_config = ConfigDict(frozen=True)
# TODO No longer in use, remove later
class SlackBotResponseType(str, Enum):
QUOTES = "quotes"
CITATIONS = "citations"
class SlackChannelConfigCreationRequest(BaseModel):
slack_bot_id: int
# currently, a persona is created for each Slack channel config
@@ -203,6 +197,7 @@ class SlackChannelConfig(BaseModel):
id: int
persona: PersonaSnapshot | None
channel_config: ChannelConfig
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[StandardAnswerCategory]
enable_auto_filters: bool
@@ -222,6 +217,7 @@ class SlackChannelConfig(BaseModel):
else None
),
channel_config=slack_channel_config_model.channel_config,
response_type=slack_channel_config_model.response_type,
# XXX this is going away soon
standard_answer_categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)

View File

@@ -118,6 +118,7 @@ def create_slack_channel_config(
slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
persona_id=persona_id,
channel_config=channel_config,
response_type=slack_channel_config_creation_request.response_type,
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
db_session=db_session,
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
@@ -181,6 +182,7 @@ def patch_slack_channel_config(
slack_channel_config_id=slack_channel_config_id,
persona_id=persona_id,
channel_config=channel_config,
response_type=slack_channel_config_creation_request.response_type,
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
)

View File

@@ -26,6 +26,7 @@ from danswer.auth.noauth_user import fetch_no_auth_user
from danswer.auth.noauth_user import set_no_auth_user_preferences
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserStatus
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
@@ -33,6 +34,7 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address
@@ -59,11 +61,9 @@ from danswer.server.manage.models import UserRoleUpdateRequest
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import BasicAuthenticationError
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.configs.app_configs import SUPER_USERS
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -194,11 +194,11 @@ def bulk_invite_users(
)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
normalized_emails = []
try:
for email in emails:
email_info = validate_email(email)
new_invited_emails.append(email_info.normalized)
normalized_emails.append(email_info.normalized) # type: ignore
except (EmailUndeliverableError, EmailNotValidError) as e:
raise HTTPException(
@@ -210,7 +210,7 @@ def bulk_invite_users(
try:
fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
)(new_invited_emails, tenant_id)
)(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
@@ -224,7 +224,7 @@ def bulk_invite_users(
initial_invited_users = get_invited_users()
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
all_emails = list(set(normalized_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
if not MULTI_TENANT:
@@ -236,7 +236,7 @@ def bulk_invite_users(
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
@@ -250,7 +250,7 @@ def bulk_invite_users(
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
)(normalized_emails, tenant_id)
raise e

View File

@@ -109,7 +109,6 @@ def process_run_in_background(
prompt_id=chat_session.persona.prompts[0].id,
search_doc_ids=None,
retrieval_options=search_tool_retrieval_details, # Adjust as needed
rerank_settings=None,
query_override=None,
regenerate=None,
llm_override=None,

View File

@@ -1,7 +1,6 @@
import asyncio
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
@@ -24,9 +23,6 @@ from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
from danswer.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
@@ -51,11 +47,13 @@ from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.file_processing.extract_file_text import docx_to_txt_filename
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
@@ -709,30 +707,14 @@ def upload_files_for_chat(
}
@router.get("/file/{file_id:path}")
@router.get("/file/{file_id}")
def fetch_chat_file(
file_id: str,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> Response:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")
return StreamingResponse(file_io, media_type=media_type)
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")

View File

@@ -1,19 +1,16 @@
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import RetrievalDocs
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchDoc
from danswer.context.search.models import Tag
@@ -23,9 +20,6 @@ from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
pass
class SourceTag(Tag):
source: DocumentSource
@@ -93,8 +87,6 @@ class CreateChatMessageRequest(ChunkContext):
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None
retrieval_options: RetrievalDetails | None
# Useable via the APIs but not recommended for most flows
rerank_settings: RerankingDetails | None = None
# allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
@@ -110,10 +102,6 @@ class CreateChatMessageRequest(ChunkContext):
# allow user to specify an alternate assistnat
alternate_assistant_id: int | None = None
# This takes the priority over the prompt_override
# This won't be a type that's passed in directly from the API
persona_override_config: PersonaOverrideConfig | None = None
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
@@ -157,7 +145,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel):
id: UUID
name: str | None
name: str
persona_id: int | None = None
time_created: str
shared_status: ChatSessionSharedStatus
@@ -210,14 +198,14 @@ class ChatMessageDetail(BaseModel):
class SearchSessionDetailResponse(BaseModel):
search_session_id: UUID
description: str | None
description: str
documents: list[SearchDoc]
messages: list[ChatMessageDetail]
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
description: str | None
description: str
persona_id: int | None = None
persona_name: str | None
messages: list[ChatMessageDetail]

View File

@@ -1,11 +1,15 @@
import json
from collections.abc import Generator
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
@@ -28,6 +32,8 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import find_tags
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.server.query_and_chat.models import AdminSearchRequest
from danswer.server.query_and_chat.models import AdminSearchResponse
from danswer.server.query_and_chat.models import ChatSessionDetails
@@ -35,6 +41,7 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
from danswer.server.query_and_chat.models import SourceTag
from danswer.server.query_and_chat.models import TagResponse
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -133,7 +140,7 @@ def get_user_search_sessions(
try:
search_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session
user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True
)
except ValueError:
raise HTTPException(
@@ -222,3 +229,29 @@ def get_search_session(
],
)
return response
@basic_router.post("/stream-answer-with-quote")
def get_answer_with_quote(
query_request: DirectQARequest,
user: User = Depends(current_limited_user),
_: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
query = query_request.messages[0].message
logger.notice(f"Received query for one shot answer with quotes: {query}")
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_search_answer(
query_req=query_request,
user=user,
max_document_tokens=None,
max_history_tokens=0,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception("Error in search answer streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")

View File

@@ -6,9 +6,6 @@ from email.mime.text import MIMEText
from textwrap import dedent
from typing import Any
from fastapi import HTTPException
from fastapi import status
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
@@ -17,11 +14,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.models import User
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MANAGED_VESPA
from danswer.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
@@ -222,13 +221,13 @@ def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
num_attempts: int = VESPA_NUM_ATTEMPTS_ON_STARTUP,
) -> bool:
# Vespa startup is a bit slow, so give it a few seconds
WAIT_SECONDS = 5
for x in range(num_attempts):
VESPA_ATTEMPTS = 5
for x in range(VESPA_ATTEMPTS):
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
@@ -245,7 +244,7 @@ def setup_vespa(
time.sleep(WAIT_SECONDS)
logger.error(
f"Vespa setup did not succeed. Attempt limit reached. ({num_attempts})"
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
)
return False

View File

@@ -7,7 +7,7 @@ from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.tool_implementations.custom.custom_tool import (
CustomToolCallSummary,
)

View File

@@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any
from typing import TYPE_CHECKING
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.utils.special_types import JSON_ro
if TYPE_CHECKING:
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse

View File

@@ -5,10 +5,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationConfig
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import PromptConfig
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
@@ -17,12 +13,15 @@ from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
@@ -103,14 +102,11 @@ class SearchToolConfig(BaseModel):
default_factory=DocumentPruningConfig
)
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
selected_sections: list[InferenceSection] | None = None
chunks_above: int = 0
chunks_below: int = 0
full_doc: bool = False
latest_query_files: list[InMemoryChatFile] | None = None
# Use with care, should only be used for DanswerBot in channels with multiple users
bypass_acl: bool = False
class InternetSearchToolConfig(BaseModel):
@@ -174,8 +170,6 @@ def construct_tools(
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
tool_dict[db_tool_model.id] = [search_tool]

View File

@@ -15,14 +15,14 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from requests import JSONDecodeError
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.tools.base_tool import BaseTool
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER

View File

@@ -4,16 +4,14 @@ from enum import Enum
from typing import Any
from typing import cast
import requests
from litellm import image_generation # type: ignore
from pydantic import BaseModel
from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
@@ -58,18 +56,9 @@ Follow Up Input:
""".strip()
class ImageFormat(str, Enum):
URL = "url"
BASE64 = "b64_json"
_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)
class ImageGenerationResponse(BaseModel):
revised_prompt: str
url: str | None
image_data: str | None
url: str
class ImageShape(str, Enum):
@@ -91,7 +80,6 @@ class ImageGenerationTool(Tool):
model: str = "dall-e-3",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
) -> None:
self.api_key = api_key
self.api_base = api_base
@@ -101,7 +89,6 @@ class ImageGenerationTool(Tool):
self.num_imgs = num_imgs
self.additional_headers = additional_headers
self.output_format = output_format
@property
def name(self) -> str:
@@ -181,7 +168,7 @@ class ImageGenerationTool(Tool):
)
return build_content_with_imgs(
message=json.dumps(
json.dumps(
[
{
"revised_prompt": image_generation.revised_prompt,
@@ -190,10 +177,13 @@ class ImageGenerationTool(Tool):
for image_generation in image_generations
]
),
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
# Tool messages to contain images
# img_urls=[image_generation.url for image_generation in image_generations],
)
def _generate_image(
self, prompt: str, shape: ImageShape, format: ImageFormat
self, prompt: str, shape: ImageShape
) -> ImageGenerationResponse:
if shape == ImageShape.LANDSCAPE:
size = "1792x1024"
@@ -207,32 +197,20 @@ class ImageGenerationTool(Tool):
prompt=prompt,
model=self.model,
api_key=self.api_key,
# need to pass in None rather than empty str
api_base=self.api_base or None,
api_version=self.api_version or None,
size=size,
n=1,
response_format=format,
extra_headers=build_llm_extra_headers(self.additional_headers),
)
if format == ImageFormat.URL:
url = response.data[0]["url"]
image_data = None
else:
url = None
image_data = response.data[0]["b64_json"]
return ImageGenerationResponse(
revised_prompt=response.data[0]["revised_prompt"],
url=url,
image_data=image_data,
url=response.data[0]["url"],
)
except requests.RequestException as e:
logger.error(f"Error fetching or converting image: {e}")
raise ValueError("Failed to fetch or convert the generated image")
except Exception as e:
logger.debug(f"Error occurred during image generation: {e}")
logger.debug(f"Error occured during image generation: {e}")
error_message = str(e)
if "OpenAIException" in str(type(e)):
@@ -257,8 +235,9 @@ class ImageGenerationTool(Tool):
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format
# dalle3 only supports 1 image at a time, which is why we have to
# parallelize this via threading
results = cast(
list[ImageGenerationResponse],
run_functions_tuples_in_parallel(
@@ -268,7 +247,6 @@ class ImageGenerationTool(Tool):
(
prompt,
shape,
format,
),
)
for _ in range(self.num_imgs)
@@ -310,17 +288,11 @@ class ImageGenerationTool(Tool):
if img_generation_response is None:
raise ValueError("No image generation response found")
img_urls = [img.url for img in img_generation_response if img.url is not None]
b64_imgs = [
img.image_data
for img in img_generation_response
if img.image_data is not None
]
img_urls = [img.url for img in img_generation_response]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
b64_imgs=b64_imgs,
)
)

View File

@@ -11,14 +11,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or
def build_image_generation_user_prompt(
query: str,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
query: str, img_urls: list[str] | None = None
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
)
)

View File

@@ -7,15 +7,15 @@ from typing import cast
import httpx
from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.context.search.models import SearchDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT
@@ -77,7 +77,6 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
updated_at=datetime.now(),
link=result.link,
source_links={0: result.link},
match_highlights=[],
)

View File

@@ -7,19 +7,10 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.llm_response_handler import LLMCall
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prune_and_merge import prune_and_merge_sections
from danswer.chat.prune_and_merge import prune_sections
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -28,14 +19,22 @@ from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
from danswer.llm.answering.prune_and_merge import prune_sections
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.message import ToolCallSummary
@@ -104,7 +103,6 @@ class SearchTool(Tool):
chunks_below: int | None = None,
full_doc: bool = False,
bypass_acl: bool = False,
rerank_settings: RerankingDetails | None = None,
) -> None:
self.user = user
self.persona = persona
@@ -120,9 +118,6 @@ class SearchTool(Tool):
self.bypass_acl = bypass_acl
self.db_session = db_session
# Only used via API
self.rerank_settings = rerank_settings
self.chunks_above = (
chunks_above
if chunks_above is not None
@@ -297,7 +292,6 @@ class SearchTool(Tool):
self.retrieval_options.offset if self.retrieval_options else None
),
limit=self.retrieval_options.limit if self.retrieval_options else None,
rerank_settings=self.rerank_settings,
chunks_above=self.chunks_above,
chunks_below=self.chunks_below,
full_doc=self.full_doc,

View File

@@ -2,15 +2,15 @@ from typing import cast
from langchain_core.messages import HumanMessage
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.citations_prompt import (
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message
from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse

View File

@@ -2,8 +2,8 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse

View File

@@ -3,8 +3,8 @@ from typing import Any
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool

Some files were not shown because too many files have changed in this diff Show More